mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix state dict serialization
This commit is contained in:
@@ -37,7 +37,6 @@ from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
|
||||
from safetensors.torch import save_file
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import crop
|
||||
from tqdm.auto import tqdm
|
||||
@@ -134,6 +133,10 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
|
||||
peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
|
||||
diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict)
|
||||
diffusers_state_dict = {
|
||||
f"{pipeline.unet_name}.{module_name}": param for module_name, param in diffusers_state_dict.items()
|
||||
}
|
||||
|
||||
pipeline.load_lora_weights(diffusers_state_dict)
|
||||
pipeline.fuse_lora()
|
||||
|
||||
@@ -1346,7 +1349,7 @@ def main(args):
|
||||
# save_file(kohya_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors"))
|
||||
peft_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
|
||||
diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict)
|
||||
save_file(diffusers_state_dict, os.path.join(args.output_dir, "pytorch_lora_weights.safetensors"))
|
||||
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, diffusers_state_dict)
|
||||
|
||||
if args.push_to_hub:
|
||||
upload_folder(
|
||||
|
||||
Reference in New Issue
Block a user