From 006d09275122d6c38ecd6e22dfe23630f44095f5 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Date: Tue, 2 Sep 2025 11:30:33 +0300
Subject: [PATCH] [Flux LoRA] fix for prior preservation and mixed precision
 sampling, follow up on #11873 (#12264)

* propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script

* propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script

* propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script

* Apply style fixes

---------

Co-authored-by: github-actions[bot]
---
 .../train_dreambooth_lora_flux_advanced.py        |  4 +++-
 examples/dreambooth/train_dreambooth_lora_flux.py | 11 ++++++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 951b989d7a..a46490e8b3 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -1399,6 +1399,7 @@ def main(args):
                 torch_dtype = torch.float16
             elif args.prior_generation_precision == "bf16":
                 torch_dtype = torch.bfloat16
+
             pipeline = FluxPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
@@ -1419,7 +1420,8 @@ def main(args):
             for example in tqdm(
                 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
             ):
-                images = pipeline(example["prompt"]).images
+                with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+                    images = pipeline(prompt=example["prompt"]).images
 
                 for i, image in enumerate(images):
                     hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 2353625c38..bd3a974a17 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -1131,6 +1131,7 @@ def main(args):
                 torch_dtype = torch.float16
             elif args.prior_generation_precision == "bf16":
                 torch_dtype = torch.bfloat16
+
             pipeline = FluxPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
@@ -1151,7 +1152,8 @@ def main(args):
             for example in tqdm(
                 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
             ):
-                images = pipeline(example["prompt"]).images
+                with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+                    images = pipeline(prompt=example["prompt"]).images
 
                 for i, image in enumerate(images):
                     hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
@@ -1159,8 +1161,7 @@ def main(args):
                     image.save(image_filename)
 
         del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1728,6 +1729,10 @@ def main(args):
                             device=accelerator.device,
                             prompt=args.instance_prompt,
                         )
+                    else:
+                        prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+                            prompts, text_encoders, tokenizers
+                        )
 
                 # Convert images to latent space
                 if args.cache_latents: