From 3deed729e677a011c1a2552faccce3cbb9303626 Mon Sep 17 00:00:00 2001 From: Boseong Jeon Date: Fri, 1 Nov 2024 13:46:05 +0900 Subject: [PATCH] Handling mixed precision for dreambooth flux lora training (#9565) Handling mixed precision and add unwarp Co-authored-by: Sayak Paul Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> --- examples/dreambooth/train_dreambooth_lora_flux.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index a0a197b1b2..e214859525 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -177,7 +177,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1706,7 +1706,7 @@ def main(args): ) # handle guidance - if transformer.config.guidance_embeds: + if accelerator.unwrap_model(transformer).config.guidance_embeds: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: @@ -1819,6 +1819,8 @@ def main(args): # create pipeline if not args.train_text_encoder: text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one.to(weight_dtype) + text_encoder_two.to(weight_dtype) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae,