1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] kijai wan lora support for I2V (#11588)

* testing

* testing

* testing

* testing

* testing

* i2v

* i2v

* device fix

* testing

* fix

* fix

* fix

* fix

* fix

* Apply style fixes

* empty commit

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Linoy Tsaban
2025-05-20 17:11:55 +03:00
committed by GitHub
parent 05c8b42b75
commit 5d4f723b57

View File

@@ -4813,22 +4813,43 @@ class WanLoraLoaderMixin(LoraBaseMixin):
if transformer.config.image_dim is None:
return state_dict
target_device = transformer.device
if any(k.startswith("transformer.blocks.") for k in state_dict):
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
has_bias = any(".lora_B.bias" in k for k in state_dict)
if is_i2v_lora:
return state_dict
for i in range(num_blocks):
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
# These keys should exist if the block `i` was part of the T2V LoRA.
ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
continue
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
)
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
)
# If the original LoRA had biases (indicated by has_bias)
# AND the specific reference bias key exists for this block.
ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
if has_bias and ref_key_lora_B_bias in state_dict:
ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
ref_lora_B_bias_tensor,
device=target_device,
)
return state_dict
def load_lora_weights(