From 7710415baf6fdff3f0ba024959376fce5efd1711 Mon Sep 17 00:00:00 2001 From: akbaig <57063370+akbaig@users.noreply.github.com> Date: Tue, 23 Jul 2024 12:14:56 +0300 Subject: [PATCH] fix: checkpoint save issue in advanced dreambooth lora sdxl script (#8926) Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> --- .../train_dreambooth_lora_sdxl_advanced.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 9d06ce6cba..075298be00 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -1605,13 +1605,15 @@ def main(args): if isinstance(model, type(unwrap_model(unet))): unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) elif isinstance(model, type(unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + if args.train_text_encoder: + text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) elif isinstance(model, type(unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( - get_peft_model_state_dict(model) - ) + if args.train_text_encoder: + text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) else: raise ValueError(f"unexpected save model: {model.__class__}")