From 8fcd52febbc63eb6b1966b20a1b0268523dd68f7 Mon Sep 17 00:00:00 2001 From: "Thuan H. Nguyen" <32274287+thuanz123@users.noreply.github.com> Date: Tue, 14 Nov 2023 01:01:15 +0700 Subject: [PATCH] Correct code for distributed training of RealFill (#5740) Correct code for distributed training --- examples/research_projects/realfill/train_realfill.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/realfill/train_realfill.py b/examples/research_projects/realfill/train_realfill.py index 1549d81305..e251d8d176 100644 --- a/examples/research_projects/realfill/train_realfill.py +++ b/examples/research_projects/realfill/train_realfill.py @@ -639,7 +639,7 @@ def main(args): for model in models: sub_dir = ( "unet" - if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model))) + if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model)) else "text_encoder" ) model.save_pretrained(os.path.join(output_dir, sub_dir)) @@ -654,12 +654,12 @@ def main(args): sub_dir = ( "unet" - if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model))) + if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model)) else "text_encoder" ) model_cls = ( UNet2DConditionModel - if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet.base_model.model))) + if isinstance(model.base_model.model, type(accelerator.unwrap_model(unet).base_model.model)) else CLIPTextModel ) @@ -937,8 +937,8 @@ def main(args): if accelerator.is_main_process: pipeline = StableDiffusionInpaintPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet.merge_and_unload(), keep_fp32_wrapper=True), - text_encoder=accelerator.unwrap_model(text_encoder.merge_and_unload(), keep_fp32_wrapper=True), + unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).merge_and_unload(), + text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).merge_and_unload(), revision=args.revision, )