mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Update train_text_to_image_lora.py (#6144)
* Update train_text_to_image_lora.py * Fix typo? --------- Co-authored-by: M. Tolga Cangöz <46008593+standardAI@users.noreply.github.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -799,7 +799,8 @@ def main():
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unwrapped_unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
|
||||
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=save_path,
|
||||
@@ -864,7 +865,8 @@ def main():
|
||||
if accelerator.is_main_process:
|
||||
unet = unet.to(torch.float32)
|
||||
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unet)
|
||||
unwrapped_unet = accelerator.unwrap_model(unet)
|
||||
unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
|
||||
StableDiffusionPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_state_dict,
|
||||
|
||||
Reference in New Issue
Block a user