From c4b5d2ff6b529ac0f895cedb04fef5b25e89c412 Mon Sep 17 00:00:00 2001
From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Date: Sun, 24 Nov 2024 18:51:06 +0200
Subject: [PATCH] [SD3 dreambooth lora] smol fix to checkpoint saving (#9993)

* smol change to fix checkpoint saving & resuming (as done in
  train_dreambooth_sd3.py)

* style

* modify comment to explain reasoning behind hidden size check
---
 examples/dreambooth/train_dreambooth_lora_sd3.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index dcf093a94c..3f721e56ad 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1294,10 +1294,13 @@ def main(args):
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                    # both text encoders are of the same class, so we check hidden size to distinguish between the two
+                    hidden_size = unwrap_model(model).config.hidden_size
+                    if hidden_size == 768:
+                        text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    elif hidden_size == 1280:
+                        text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
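Context for reviewers: below is a minimal standalone sketch (not part of the patch) of why the hidden-size check is needed. In SD3, both CLIP text encoders are instances of the same `transformers` class, so the old `isinstance(model, type(unwrap_model(text_encoder_two)))` branch could never be reached; `config.hidden_size` (768 for encoder one, 1280 for encoder two, as in the hunk above) disambiguates them. The helper name `which_encoder` and the tiny one-layer configs are hypothetical, chosen only to keep the example fast to run.

```python
from transformers import CLIPTextConfig, CLIPTextModelWithProjection


def which_encoder(model):
    # Both SD3 CLIP text encoders share the same class, so an isinstance()
    # check alone is ambiguous; hidden_size (768 vs. 1280) tells them apart,
    # mirroring the check added in the patch.
    hidden_size = model.config.hidden_size
    if hidden_size == 768:
        return "text_encoder_one"
    elif hidden_size == 1280:
        return "text_encoder_two"
    raise ValueError(f"unexpected hidden size: {hidden_size}")


# Tiny illustrative configs; real SD3 checkpoints ship encoders with these
# hidden sizes but far more layers.
cfg_one = CLIPTextConfig(hidden_size=768, num_hidden_layers=1, num_attention_heads=8, intermediate_size=64)
cfg_two = CLIPTextConfig(hidden_size=1280, num_hidden_layers=1, num_attention_heads=8, intermediate_size=64)
enc_one = CLIPTextModelWithProjection(cfg_one)
enc_two = CLIPTextModelWithProjection(cfg_two)

assert type(enc_one) is type(enc_two)  # a class check cannot distinguish them
print(which_encoder(enc_one))  # text_encoder_one
print(which_encoder(enc_two))  # text_encoder_two
```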