mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[LoRA] attempt at fixing onetrainer lora. (#8242)
* attempt at fixing onetrainer lora. * fix
This commit is contained in:
@@ -226,6 +226,8 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
|
||||
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
||||
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
||||
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
||||
diffusers_name = diffusers_name.replace("text.projection", "text_projection")
|
||||
|
||||
if "self_attn" in diffusers_name:
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
te_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
@@ -243,6 +245,10 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
|
||||
else:
|
||||
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
# OneTrainer specificity
|
||||
elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
|
||||
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
||||
dora_scale_key_to_replace_te = (
|
||||
@@ -270,7 +276,7 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
|
||||
network_alphas.update({new_name: alpha})
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}")
|
||||
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
|
||||
|
||||
logger.info("Kohya-style checkpoint detected.")
|
||||
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
||||
|
||||
@@ -62,6 +62,8 @@ DIFFUSERS_TO_PEFT = {
|
||||
".out_proj.lora_linear_layer.down": ".out_proj.lora_A",
|
||||
".lora_linear_layer.up": ".lora_B",
|
||||
".lora_linear_layer.down": ".lora_A",
|
||||
"text_projection.lora.down.weight": "text_projection.lora_A.weight",
|
||||
"text_projection.lora.up.weight": "text_projection.lora_B.weight",
|
||||
}
|
||||
|
||||
DIFFUSERS_OLD_TO_PEFT = {
|
||||
|
||||
Reference in New Issue
Block a user