1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Correct code for distributed training of RealFill (#5740)

Correct code for distributed training
This commit is contained in:
Thuan H. Nguyen
2023-11-14 01:01:15 +07:00
committed by GitHub
parent 0488810f61
commit 8fcd52febb

View File

@@ -639,7 +639,7 @@ def main(args):
for model in models:
sub_dir = (
"unet"
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model)))
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
else "text_encoder"
)
model.save_pretrained(os.path.join(output_dir, sub_dir))
@@ -654,12 +654,12 @@ def main(args):
sub_dir = (
"unet"
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model)))
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
else "text_encoder"
)
model_cls = (
UNet2DConditionModel
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model)))
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
else CLIPTextModel
)
@@ -937,8 +937,8 @@ def main(args):
if accelerator.is_main_process:
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet.merge_and_unload(), keep_fp32_wrapper=True),
text_encoder=accelerator.unwrap_model(text_encoder.merge_and_unload(), keep_fp32_wrapper=True),
unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).merge_and_unload(),
text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).merge_and_unload(),
revision=args.revision,
)