1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fix state dict serialization

This commit is contained in:
sayakpaul
2023-11-14 11:04:02 +05:30
parent 4135414907
commit 3b066d2657

View File

@@ -37,7 +37,6 @@ from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
from safetensors.torch import save_file
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
@@ -134,6 +133,10 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict)
diffusers_state_dict = {
f"{pipeline.unet_name}.{module_name}": param for module_name, param in diffusers_state_dict.items()
}
pipeline.load_lora_weights(diffusers_state_dict)
pipeline.fuse_lora()
@@ -1346,7 +1349,7 @@ def main(args):
# save_file(kohya_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors"))
peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict)
save_file(diffusers_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors"))
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, diffusers_state_dict)
if args.push_to_hub:
upload_folder(