Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[LoRA] fix peft state dict parsing (#10532)
* fix peft state dict parsing
* updates
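In short: some checkpoint files carry a peft-format state dict (keys prefixed with `transformer.`) next to keys in other layouts, and kohya-style "mixture" checkpoints store their transformer LoRA weights under `lora_transformer_*` keys. The diff below teaches the converter to detect both cases. A minimal sketch of the detection logic, not part of the commit, with illustrative key names and placeholder values:

    # One peft-style key and one kohya "mixture" key in the same state dict.
    state_dict = {
        "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": None,
        "lora_transformer_transformer_blocks_0_attn_to_q.lora_down.weight": None,
    }

    # Same predicates as in the diff.
    has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
    has_mixture = any(
        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k)
        for k in state_dict
    )

    # When both formats are present, the peft subset wins and the rest is dropped.
    if has_peft_state_dict:
        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
    print(sorted(state_dict))  # only the "transformer."-prefixed key survives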
@@ -519,7 +519,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
         remaining_keys = list(sds_sd.keys())
         te_state_dict = {}
         if remaining_keys:
-            if not all(k.startswith("lora_te1") for k in remaining_keys):
+            if not all(k.startswith("lora_te") for k in remaining_keys):
                 raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
             for key in remaining_keys:
                 if not key.endswith("lora_down.weight"):
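The only change in this hunk relaxes the text-encoder prefix check from `lora_te1` to `lora_te`, so keys without the trailing digit no longer trip the incompatible-keys error. A quick illustration with hypothetical key names:

    remaining_keys = [
        "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight",
        "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight",
    ]
    assert all(k.startswith("lora_te") for k in remaining_keys)  # new check: passes
    assert not all(k.startswith("lora_te1") for k in remaining_keys)  # old check: would raise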
@@ -558,6 +558,88 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
         new_state_dict = {**ait_sd, **te_state_dict}
         return new_state_dict
 
+    def _convert_mixture_state_dict_to_diffusers(state_dict):
+        new_state_dict = {}
+
+        def _convert(original_key, diffusers_key, state_dict, new_state_dict):
+            down_key = f"{original_key}.lora_down.weight"
+            down_weight = state_dict.pop(down_key)
+            lora_rank = down_weight.shape[0]
+
+            up_weight_key = f"{original_key}.lora_up.weight"
+            up_weight = state_dict.pop(up_weight_key)
+
+            alpha_key = f"{original_key}.alpha"
+            alpha = state_dict.pop(alpha_key)
+
+            # scale weight by alpha and dim
+            scale = alpha / lora_rank
+            # calculate scale_down and scale_up
+            scale_down = scale
+            scale_up = 1.0
+            while scale_down * 2 < scale_up:
+                scale_down *= 2
+                scale_up /= 2
+            down_weight = down_weight * scale_down
+            up_weight = up_weight * scale_up
+
+            diffusers_down_key = f"{diffusers_key}.lora_A.weight"
+            new_state_dict[diffusers_down_key] = down_weight
+            new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
+
+        all_unique_keys = {
+            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
+        }
+        all_unique_keys = sorted(all_unique_keys)
+        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
+
+        for k in all_unique_keys:
+            if k.startswith("lora_transformer_single_transformer_blocks_"):
+                i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
+                diffusers_key = f"single_transformer_blocks.{i}"
+            elif k.startswith("lora_transformer_transformer_blocks_"):
+                i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
+                diffusers_key = f"transformer_blocks.{i}"
+            else:
+                raise NotImplementedError
+
+            if "attn_" in k:
+                if "_to_out_0" in k:
+                    diffusers_key += ".attn.to_out.0"
+                elif "_to_add_out" in k:
+                    diffusers_key += ".attn.to_add_out"
+                elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
+                    remaining = k.split("attn_")[-1]
+                    diffusers_key += f".attn.{remaining}"
+                elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
+                    remaining = k.split("attn_")[-1]
+                    diffusers_key += f".attn.{remaining}"
+
+            if diffusers_key == f"transformer_blocks.{i}":
+                print(k, diffusers_key)
+            _convert(k, diffusers_key, state_dict, new_state_dict)
+
+        if len(state_dict) > 0:
+            raise ValueError(
+                f"Expected an empty state dict at this point but it has these keys which couldn't be parsed: {list(state_dict.keys())}."
+            )
+
+        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
+        return new_state_dict
+
+    # This is weird.
+    # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
+    # has both `peft` and non-peft state dict.
+    has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
+    if has_peft_state_dict:
+        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+        return state_dict
+    # Another weird one.
+    has_mixture = any(
+        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
+    )
+    if has_mixture:
+        return _convert_mixture_state_dict_to_diffusers(state_dict)
     return _convert_sd_scripts_to_ai_toolkit(state_dict)
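A detail worth calling out in `_convert` above: the kohya scale `alpha / lora_rank` is folded into the weights, split between the down and up matrices by moving factors of two until the two sides are balanced, so their product always recovers the original scale. A standalone trace of that loop, not part of the commit:

    def split_scale(alpha, lora_rank):
        scale = alpha / lora_rank
        scale_down, scale_up = scale, 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2
        return scale_down, scale_up

    for alpha, rank in [(16, 16), (8, 16), (1, 16)]:
        down, up = split_scale(alpha, rank)
        assert down * up == alpha / rank  # the product is preserved exactly
        print(f"alpha={alpha}, rank={rank} -> scale_down={down}, scale_up={up}")
    # alpha=16, rank=16 -> scale_down=1.0, scale_up=1.0
    # alpha=8, rank=16 -> scale_down=0.5, scale_up=1.0
    # alpha=1, rank=16 -> scale_down=0.25, scale_up=0.25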
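And the key-renaming rules from the same function, applied to a hypothetical mixture key (block index and module name invented for the demo):

    k = "lora_transformer_single_transformer_blocks_7_attn_to_q"
    # The block index is the segment right after the block prefix.
    i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
    # The attention sub-module is whatever follows "attn_".
    remaining = k.split("attn_")[-1]
    diffusers_key = f"single_transformer_blocks.{i}.attn.{remaining}"
    assert diffusers_key == "single_transformer_blocks.7.attn.to_q"
    # `_convert` then writes f"{diffusers_key}.lora_A.weight" / ".lora_B.weight",
    # and the final pass prefixes every key with "transformer.".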