1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] attempt at fixing onetrainer lora. (#8242)

* attempt at fixing onetrainer lora.

* fix
This commit is contained in:
Sayak Paul
2024-05-28 23:55:54 +05:30
committed by GitHub
parent 80cfaebaa1
commit e6df8edadc
2 changed files with 9 additions and 1 deletions

View File

@@ -226,6 +226,8 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
if "self_attn" in diffusers_name:
if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[diffusers_name] = state_dict.pop(key)
@@ -243,6 +245,10 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
else:
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# OneTrainer specificity
elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
dora_scale_key_to_replace_te = (
@@ -270,7 +276,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
network_alphas.update({new_name: alpha})
if len(state_dict) > 0:
raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}")
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
logger.info("Kohya-style checkpoint detected.")
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}

View File

@@ -62,6 +62,8 @@ DIFFUSERS_TO_PEFT = {
".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
".lora_linear_layer.up": ".lora_B",
".lora_linear_layer.down": ".lora_A",
"text_projection.lora.down.weight": "text_projection.lora_A.weight",
"text_projection.lora.up.weight": "text_projection.lora_B.weight",
}
DIFFUSERS_OLD_TO_PEFT = {