From 29dfe22a8e6f1ea1e1f6cd4fbb8381f08064091e Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Thu, 14 Dec 2023 11:45:33 -0600 Subject: [PATCH] [advanced dreambooth lora sdxl training script] load pipeline for inference only if validation prompt is used (#6171) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * load pipeline for inference only if validation prompt is used * move things outside * load pipeline for inference only if validation prompt is used * fix readme when validation prompt is used --------- Co-authored-by: linoytsaban Co-authored-by: apolinário --- .../train_dreambooth_lora_sdxl_advanced.py | 77 ++++++++++--------- 1 file changed, 39 insertions(+), 38 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index a46a1afcc1..ad37363b7d 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -112,7 +112,7 @@ def save_model_card( repo_folder=None, vae_path=None, ): - img_str = "widget:\n" if images else "" + img_str = "widget:\n" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) img_str += f""" @@ -121,6 +121,10 @@ def save_model_card( url: "image_{i}.png" """ + if not images: + img_str += f""" + - text: '{instance_prompt}' + """ trigger_str = f"You should use {instance_prompt} to trigger the image generation." diffusers_imports_pivotal = "" @@ -157,8 +161,6 @@ tags: base_model: {base_model} instance_prompt: {instance_prompt} license: openrail++ -widget: - - text: '{validation_prompt if validation_prompt else instance_prompt}' --- """ @@ -2010,43 +2012,42 @@ def main(args): text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) - - # Final inference - # Load previous pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - - # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - - # load attention processors - pipeline.load_lora_weights(args.output_dir) - - # run inference images = [] if args.validation_prompt and args.num_validation_images > 0: + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference pipeline = pipeline.to(accelerator.device) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None images = [