
[SD3 dreambooth lora] smol fix to checkpoint saving (#9993)

* smol change to fix checkpoint saving & resuming (as done in train_dreambooth_sd3.py)

* style

* modify comment to explain reasoning behind hidden size check
Linoy Tsaban
2024-11-24 18:51:06 +02:00
committed by GitHub
parent 7ac6e286ee
commit c4b5d2ff6b


@@ -1294,10 +1294,13 @@ def main(args):
 
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                    # both text encoders are of the same class, so we check hidden size to distinguish between the two
+                    hidden_size = unwrap_model(model).config.hidden_size
+                    if hidden_size == 768:
+                        text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    elif hidden_size == 1280:
+                        text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
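
For context: SD3 loads both CLIP text encoders as the same transformers class, so in the old code the first elif matched for either encoder and the second encoder's LoRA layers were never routed correctly when saving a checkpoint. The sketch below shows the hidden-size check in isolation; it is a minimal illustration, not the training script itself. The tiny one-layer configs are hypothetical stand-ins for the real SD3 checkpoints, and the 768/1280 values mirror the diff above.

    # Minimal sketch of the hidden-size check, assuming `transformers` is installed.
    # The one-layer configs below are illustrative stand-ins, not the real SD3 text encoders.
    from transformers import CLIPTextConfig, CLIPTextModelWithProjection

    text_encoder_one = CLIPTextModelWithProjection(CLIPTextConfig(hidden_size=768, num_hidden_layers=1))
    text_encoder_two = CLIPTextModelWithProjection(CLIPTextConfig(hidden_size=1280, num_hidden_layers=1))

    # Both encoders are instances of the same class, so isinstance() alone cannot tell them apart.
    assert type(text_encoder_one) is type(text_encoder_two)

    def which_encoder(model) -> str:
        # Distinguish the two encoders by hidden size, mirroring the check in the diff above.
        hidden_size = model.config.hidden_size
        if hidden_size == 768:
            return "text_encoder_one"
        elif hidden_size == 1280:
            return "text_encoder_two"
        raise ValueError(f"unexpected save model: {model.__class__}")

    print(which_encoder(text_encoder_one))  # -> text_encoder_one
    print(which_encoder(text_encoder_two))  # -> text_encoder_two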