diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index b93c4327bb..d2defff51a 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -722,13 +722,13 @@ def main(args): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_ + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ ) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_ + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ ) accelerator.register_save_state_pre_hook(save_model_hook)