diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index ae904adb19..0c03584068 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1289,8 +1289,8 @@ def main(args): models = [unet_] if args.train_text_encoder: models.extend([text_encoder_one_, text_encoder_two_]) - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py index e8c9cb796a..00f95509be 100644 --- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py @@ -1363,8 +1363,8 @@ def main(args): models = [unet_] if args.train_text_encoder: models.extend([text_encoder_one_, text_encoder_two_]) - # only upcast trainable parameters (LoRA) into fp32 - cast_training_params(models) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook)