mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[SD3 dreambooth lora] smol fix to checkpoint saving (#9993)
* smol change to fix checkpoint saving & resuming (as done in train_dreambooth_sd3.py) * style * modify comment to explain reasoning behind hidden size check
This commit is contained in:
@@ -1294,10 +1294,13 @@ def main(args):
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))): # or text_encoder_two
|
||||
# both text encoders are of the same class, so we check hidden size to distinguish between the two
|
||||
hidden_size = unwrap_model(model).config.hidden_size
|
||||
if hidden_size == 768:
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif hidden_size == 1280:
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user