mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[LoRA] kijai wan lora support for I2V (#11588)
* testing * testing * testing * testing * testing * i2v * i2v * device fix * testing * fix * fix * fix * fix * fix * Apply style fixes * empty commit --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -4813,22 +4813,43 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         if transformer.config.image_dim is None:
             return state_dict
 
+        target_device = transformer.device
+
         if any(k.startswith("transformer.blocks.") for k in state_dict):
-            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
+            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
             is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+            has_bias = any(".lora_B.bias" in k for k in state_dict)
 
             if is_i2v_lora:
                 return state_dict
 
             for i in range(num_blocks):
                 for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    # These keys should exist if the block `i` was part of the T2V LoRA.
+                    ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+                    ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+                    if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+                        continue
+
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
                     )
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
                     )
 
+                    # If the original LoRA had biases (indicated by has_bias)
+                    # AND the specific reference bias key exists for this block.
+
+                    ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+                    if has_bias and ref_key_lora_B_bias in state_dict:
+                        ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+                        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
+                            ref_lora_B_bias_tensor,
+                            device=target_device,
+                        )
+
         return state_dict
 
     def load_lora_weights(
Reference in New Issue
Block a user