mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Correct code for distributed training of RealFill (#5740)
Correct code for distributed training
This commit is contained in:
@@ -639,7 +639,7 @@ def main(args):
|
||||
for model in models:
|
||||
sub_dir = (
|
||||
"unet"
|
||||
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model)))
|
||||
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
|
||||
else "text_encoder"
|
||||
)
|
||||
model.save_pretrained(os.path.join(output_dir, sub_dir))
|
||||
@@ -654,12 +654,12 @@ def main(args):
|
||||
|
||||
sub_dir = (
|
||||
"unet"
|
||||
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model)))
|
||||
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
|
||||
else "text_encoder"
|
||||
)
|
||||
model_cls = (
|
||||
UNet2DConditionModel
|
||||
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model)))
|
||||
if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model))
|
||||
else CLIPTextModel
|
||||
)
|
||||
|
||||
@@ -937,8 +937,8 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet.merge_and_unload(), keep_fp32_wrapper=True),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder.merge_and_unload(), keep_fp32_wrapper=True),
|
||||
unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).merge_and_unload(),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).merge_and_unload(),
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user