From e2ead7cdcc00859533e6bec7b0707a6fb0efef0a Mon Sep 17 00:00:00 2001 From: Leo Jiang <74156916+leisuzz@users.noreply.github.com> Date: Sat, 14 Sep 2024 18:59:38 +0800 Subject: [PATCH] Fix the issue on sd3 dreambooth w./w.t. lora training (#9419) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix dtype error * [bugfix] Fixed the issue on sd3 dreambooth training * [bugfix] Fixed the issue on sd3 dreambooth training --------- Co-authored-by: 蒋硕 Co-authored-by: Sayak Paul --- examples/dreambooth/train_dreambooth_flux.py | 5 ++++- examples/dreambooth/train_dreambooth_lora.py | 5 ++++- examples/dreambooth/train_dreambooth_lora_flux.py | 5 ++++- examples/dreambooth/train_dreambooth_lora_sd3.py | 5 ++++- examples/dreambooth/train_dreambooth_lora_sdxl.py | 5 ++++- examples/dreambooth/train_dreambooth_sd3.py | 5 ++++- 6 files changed, 24 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index da571cc46c..8e0f4e09a4 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -154,13 +154,14 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1717,6 +1718,7 @@ def main(args): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, + torch_dtype=weight_dtype, ) if not args.train_text_encoder: del text_encoder_one, text_encoder_two @@ -1761,6 +1763,7 @@ def main(args): pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 331b2d6ab6..5d7d697bb2 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -122,6 +122,7 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( @@ -141,7 +142,7 @@ def log_validation( pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1360,6 +1361,7 @@ def main(args): accelerator, pipeline_args, epoch, + torch_dtype=weight_dtype, ) # Save the lora layers @@ -1402,6 +1404,7 @@ def main(args): pipeline_args, epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 48d669418f..bd5b46cc9f 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -170,13 +170,14 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1785,6 +1786,7 @@ def main(args): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, + torch_dtype=weight_dtype, ) if not args.train_text_encoder: del text_encoder_one, text_encoder_two @@ -1832,6 +1834,7 @@ def main(args): pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 17e6e107b0..3060813bbb 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -179,13 +179,14 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1788,6 +1789,7 @@ def main(args): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, + torch_dtype=weight_dtype, ) objs = [] if not args.train_text_encoder: @@ -1840,6 +1842,7 @@ def main(args): pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 17cc00db95..016464165c 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -180,6 +180,7 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( @@ -201,7 +202,7 @@ def log_validation( pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1890,6 +1891,7 @@ def main(args): accelerator, pipeline_args, epoch, + torch_dtype=weight_dtype, ) # Save the lora layers @@ -1955,6 +1957,7 @@ def main(args): pipeline_args, epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 985814205d..c34024f478 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -157,13 +157,14 @@ def log_validation( accelerator, pipeline_args, epoch, + torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1725,6 +1726,7 @@ def main(args): accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, + torch_dtype=weight_dtype, ) if not args.train_text_encoder: del text_encoder_one, text_encoder_two, text_encoder_three @@ -1775,6 +1777,7 @@ def main(args): pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, + torch_dtype=weight_dtype, ) if args.push_to_hub: