From 23a2cd33379015b8d59f235dbf8879272175adf9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 18 Jun 2024 14:57:34 +0100 Subject: [PATCH] [LoRA] training fix the position of param casting when loading them (#8460) fix the position of param casting when loading them --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 4 ++-- .../dreambooth/train_dreambooth_lora_sdxl.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index ae904adb19..0c03584068 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1289,8 +1289,8 @@ def main(args): models = [unet_] if args.train_text_encoder: models.extend([text_encoder_one_, text_encoder_two_]) - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py index e8c9cb796a..00f95509be 100644 --- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py @@ -1363,8 +1363,8 @@ def main(args): models = [unet_] if args.train_text_encoder: models.extend([text_encoder_one_, text_encoder_two_]) - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook)