mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix: checkpoint save issue in advanced dreambooth lora sdxl script (#8926)
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
This commit is contained in:
@@ -1605,13 +1605,15 @@ def main(args):
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user