diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index a43bc6f977..00088767a4 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -108,12 +108,6 @@ DreamBooth for the text encoder was enabled: {train_text_encoder}. f.write(yaml + model_card) -def unwrap_model(accelerator, model): - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model - return model - - def log_validation( text_encoder, tokenizer, @@ -136,15 +130,12 @@ def log_validation( if vae is not None: pipeline_args["vae"] = vae - if text_encoder is not None: - text_encoder = unwrap_model(accelerator, text_encoder) - # create pipeline (note: unet and vae are loaded again in float32) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, tokenizer=tokenizer, text_encoder=text_encoder, - unet=unwrap_model(accelerator, unet), + unet=unet, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -939,6 +930,11 @@ def main(args): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: @@ -1358,9 +1354,9 @@ def main(args): if args.validation_prompt is not None and global_step % args.validation_steps == 0: images = log_validation( - text_encoder, + unwrap_model(text_encoder) if text_encoder is not None else text_encoder, tokenizer, - unet, + unwrap_model(unet), vae, args, accelerator,