From 4191fc30f4d0c133ed34cc3664636ca9426bd9f8 Mon Sep 17 00:00:00 2001
From: Pham Hong Vinh
Date: Fri, 12 Jan 2024 12:32:33 +0700
Subject: [PATCH] change unwrap call

---
 examples/dreambooth/train_dreambooth.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 00088767a4..17b13db22a 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -939,7 +939,7 @@ def main(args):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             for model in models:
-                sub_dir = "unet" if isinstance(model, type(unwrap_model(accelerator, unet))) else "text_encoder"
+                sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
                 model.save_pretrained(os.path.join(output_dir, sub_dir))

                 # make sure to pop weight so that corresponding model is not saved again
@@ -950,7 +950,7 @@ def main(args):
             # pop models so that they are not loaded again
             model = models.pop()

-            if isinstance(model, type(unwrap_model(accelerator, text_encoder))):
+            if isinstance(model, type(unwrap_model(text_encoder))):
                 # load transformers style into model
                 load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
                 model.config = load_model.config
@@ -995,14 +995,14 @@ def main(args):
             " doing mixed precision training. copy of the weights should still be float32."
         )

-    if unwrap_model(accelerator, unet).dtype != torch.float32:
+    if unwrap_model(unet).dtype != torch.float32:
         raise ValueError(
-            f"Unet loaded as datatype {unwrap_model(accelerator, unet).dtype}. {low_precision_error_string}"
+            f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}"
         )

-    if args.train_text_encoder and unwrap_model(accelerator, text_encoder).dtype != torch.float32:
+    if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
         raise ValueError(
-            f"Text encoder loaded as datatype {unwrap_model(accelerator, text_encoder).dtype}."
+            f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}."
             f" {low_precision_error_string}"
         )

@@ -1250,7 +1250,7 @@ def main(args):
                         text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                     )

-                if unwrap_model(accelerator, unet).config.in_channels == channels * 2:
+                if unwrap_model(unet).config.in_channels == channels * 2:
                     noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)

                 if args.class_labels_conditioning == "timesteps":
@@ -1379,14 +1379,14 @@ def main(args):
         pipeline_args = {}

         if text_encoder is not None:
-            pipeline_args["text_encoder"] = unwrap_model(accelerator, text_encoder)
+            pipeline_args["text_encoder"] = unwrap_model(text_encoder)

         if args.skip_save_text_encoder:
             pipeline_args["text_encoder"] = None

         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            unet=unwrap_model(accelerator, unet),
+            unet=unwrap_model(unet),
             revision=args.revision,
             variant=args.variant,
             **pipeline_args,
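
Note on the change (not part of the patch itself): every call site drops the explicit
`accelerator` argument, so the patch assumes `unwrap_model` is now a helper defined inside
`main()` that closes over `accelerator` directly instead of receiving it as a parameter.
The helper's definition is not shown in this diff; a minimal sketch of what it could look
like, assuming the `is_compiled_module` utility from `diffusers.utils.torch_utils`:

    def unwrap_model(model):
        # Strip the wrapper (e.g. DistributedDataParallel) that
        # accelerator.prepare() adds, recovering the underlying module.
        model = accelerator.unwrap_model(model)
        # If the module was wrapped by torch.compile, return the original
        # eager-mode module stored on its _orig_mod attribute.
        model = model._orig_mod if is_compiled_module(model) else model
        return model

Closing over `accelerator` keeps the call sites shorter and matches how the hooks above
already use the enclosing scope (e.g. `accelerator.is_main_process` in `save_model_hook`).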