[WIP][LoRA] start supporting kijai wan lora. (#11579)

* start supporting kijai wan lora.

* diff_b keys.

* Apply suggestions from code review

Co-authored-by: Aryan <aryan@huggingface.co>

* merge ready

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Sayak Paul authored 2025-05-19 20:47:44 +05:30, committed by GitHub
parent ceb7af277c
commit 00f9273da2

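For reference, the conversion below renames non-diffusers Wan LoRA keys (Kijai/WanVideo-style checkpoints, prefixed with `diffusion_model.`, using `lora_down`/`lora_up` or `lora_A`/`lora_B` suffixes, and optionally carrying `diff_b` bias deltas) into the diffusers transformer layout. A minimal sketch of the renaming for one self-attention query projection, with illustrative key names derived from the diff itself:

# Sketch only: the renaming performed by _convert_non_diffusers_wan_lora_to_diffusers
# for a single projection (key names illustrative).
source_key = "diffusion_model.blocks.0.self_attn.q.lora_down.weight"
converted_key = "transformer.blocks.0.attn1.to_q.lora_A.weight"

# diff_b entries are treated as LoRA biases on the matching lora_B module:
source_bias_key = "diffusion_model.blocks.0.self_attn.q.diff_b"
converted_bias_key = "transformer.blocks.0.attn1.to_q.lora_B.bias"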

@@ -1596,48 +1596,131 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     converted_state_dict = {}
     original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
 
-    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
+    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k})
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+    lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
+    lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+
+    diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
+    if diff_keys:
+        for diff_k in diff_keys:
+            param = original_state_dict[diff_k]
+            all_zero = torch.all(param == 0).item()
+            if all_zero:
+                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+                original_state_dict.pop(diff_k)
+
+    # For the `diff_b` keys, we treat them as lora_bias.
+    # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
 
     for i in range(num_blocks):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
             converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.lora_A.weight"
+                f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.lora_B.weight"
+                f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.self_attn.{o}.diff_b"
+                )
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.cross_attn.{o}.diff_b"
+                )
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
                 converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                    f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
                 )
                 converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                    f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
                 )
+                if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
+                    converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
+                        f"blocks.{i}.cross_attn.{o}.diff_b"
+                    )
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
             converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.lora_A.weight"
+                f"blocks.{i}.{o}.{lora_down_key}.weight"
             )
             converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.lora_B.weight"
+                f"blocks.{i}.{o}.{lora_up_key}.weight"
             )
+            if f"blocks.{i}.{o}.diff_b" in original_state_dict:
+                converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
+                    f"blocks.{i}.{o}.diff_b"
+                )
 
+    # Remaining.
+    if original_state_dict:
+        if any("time_projection" in k for k in original_state_dict):
+            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
+                f"time_projection.1.{lora_down_key}.weight"
+            )
+            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
+                f"time_projection.1.{lora_up_key}.weight"
+            )
+            if "time_projection.1.diff_b" in original_state_dict:
+                converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
+                    "time_projection.1.diff_b"
+                )
+
+        if any("head.head" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+                f"head.head.{lora_down_key}.weight"
+            )
+            converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+            if "head.head.diff_b" in original_state_dict:
+                converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
+
+        for text_time in ["text_embedding", "time_embedding"]:
+            if any(text_time in k for k in original_state_dict):
+                for b_n in [0, 2]:
+                    diffusers_b_n = 1 if b_n == 0 else 2
+                    diffusers_name = (
+                        "condition_embedder.text_embedder"
+                        if text_time == "text_embedding"
+                        else "condition_embedder.time_embedder"
+                    )
+                    if any(f"{text_time}.{b_n}" in k for k in original_state_dict):
+                        converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = (
+                            original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight")
+                        )
+                        converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = (
+                            original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight")
+                        )
+                        if f"{text_time}.{b_n}.diff_b" in original_state_dict:
+                            converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = (
+                                original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
+                            )
+
     if len(original_state_dict) > 0:
-        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+        diff = all(".diff" in k for k in original_state_dict)
+        if diff:
+            diff_keys = {k for k in original_state_dict if k.endswith(".diff")}
+            if not all("lora" not in k for k in diff_keys):
+                raise ValueError
+            logger.info(
+                "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: "
+                "https://github.com/huggingface/diffusers//issues/new"
+            )
+        else:
+            raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
 
     for key in list(converted_state_dict.keys()):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
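After conversion every key carries the `transformer.` prefix, so a Kijai-style Wan LoRA should load like any other Wan LoRA. A minimal usage sketch, assuming the standard WanPipeline LoRA loading path detects the non-diffusers keys and routes them through this converter; the base model id is the public diffusers checkpoint, while the LoRA path/repo is a placeholder, not something named in this commit:

import torch
from diffusers import WanPipeline

# Load the diffusers-format Wan base model.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)

# Placeholder path to a Kijai/WanVideo-style LoRA checkpoint; when the state dict
# uses the `diffusion_model.` key prefix, the loader is expected to convert it with
# _convert_non_diffusers_wan_lora_to_diffusers before applying it.
pipe.load_lora_weights("path/or/repo-id-of-kijai-wan-lora", adapter_name="wan_lora")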