1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Wan 2.2 LoRA] add support for 2nd transformer lora loading + wan 2.2 lightx2v lora (#12074)

* add alpha

* load into 2nd transformer

* Update src/diffusers/loaders/lora_conversion_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/loaders/lora_conversion_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* pr comments

* pr comments

* pr comments

* fix

* fix

* Apply style fixes

* fix copies

* fix

* fix copies

* Update src/diffusers/loaders/lora_pipeline.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* revert change

* revert change

* fix copies

* up

* fix

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: linoy <linoy@hf.co>
This commit is contained in:
Linoy Tsaban
2025-08-19 06:02:39 +03:00
committed by GitHub
parent 8cc528c5e7
commit 8d1de40891
4 changed files with 147 additions and 52 deletions

View File

@@ -333,6 +333,8 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
- Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involed. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.
- Wan 2.2 has two denoisers. By default, LoRAs are only loaded into the first denoiser. One can set `load_into_transformer_2=True` to load LoRAs into the second denoiser. Refer to [this](https://github.com/huggingface/diffusers/pull/12074#issue-3292620048) and [this](https://github.com/huggingface/diffusers/pull/12074#issuecomment-3155896144) examples to learn more.
## WanPipeline
[[autodoc]] WanPipeline

View File

@@ -754,7 +754,11 @@ class LoraBaseMixin:
# Decompose weights into weights for denoiser and text encoders.
_component_adapter_weights = {}
for component in self._lora_loadable_modules:
model = getattr(self, component)
model = getattr(self, component, None)
# To guard for cases like Wan. In Wan2.1 and WanVace, we have a single denoiser.
# Whereas in Wan 2.2, we have two denoisers.
if model is None:
continue
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):

View File

@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
)
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = original_state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
for key in list(original_state_dict.keys()):
if key.endswith((".diff", ".diff_b")) and "norm" in key:
# NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
for i in range(min_block, max_block + 1):
# Self-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
has_alpha = alpha_key in original_state_dict
original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
if has_alpha:
down_weight = original_state_dict.pop(original_key_A)
up_weight = original_state_dict.pop(original_key_B)
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[converted_key_A] = down_weight * scale_down
converted_state_dict[converted_key_B] = up_weight * scale_up
else:
if original_key_A in original_state_dict:
converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
if original_key_B in original_state_dict:
converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
original_key = f"blocks.{i}.self_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
# Cross-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
has_alpha = alpha_key in original_state_dict
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key_A in original_state_dict:
down_weight = original_state_dict.pop(original_key_A)
converted_state_dict[converted_key_A] = down_weight
if original_key_B in original_state_dict:
up_weight = original_state_dict.pop(original_key_B)
converted_state_dict[converted_key_B] = up_weight
if has_alpha:
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[converted_key_A] *= scale_down
converted_state_dict[converted_key_B] *= scale_up
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
if is_i2v_lora:
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
has_alpha = alpha_key in original_state_dict
original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
if original_key_A in original_state_dict:
down_weight = original_state_dict.pop(original_key_A)
converted_state_dict[converted_key_A] = down_weight
if original_key_B in original_state_dict:
up_weight = original_state_dict.pop(original_key_B)
converted_state_dict[converted_key_B] = up_weight
if has_alpha:
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[converted_key_A] *= scale_down
converted_state_dict[converted_key_B] *= scale_up
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
# FFN
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
alpha_key = f"blocks.{i}.{o}.alpha"
has_alpha = alpha_key in original_state_dict
original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
if original_key in original_state_dict:
converted_state_dict[converted_key] = original_state_dict.pop(original_key)
original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
if original_key_A in original_state_dict:
down_weight = original_state_dict.pop(original_key_A)
converted_state_dict[converted_key_A] = down_weight
if original_key_B in original_state_dict:
up_weight = original_state_dict.pop(original_key_B)
converted_state_dict[converted_key_B] = up_weight
if has_alpha:
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[converted_key_A] *= scale_down
converted_state_dict[converted_key_B] *= scale_up
original_key = f"blocks.{i}.{o}.diff_b"
converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"

View File

@@ -5065,7 +5065,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
"""
_lora_loadable_modules = ["transformer"]
_lora_loadable_modules = ["transformer", "transformer_2"]
transformer_name = TRANSFORMER_NAME
@classmethod
@@ -5270,15 +5270,35 @@ class WanLoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
if not hasattr(self, "transformer_2"):
raise AttributeError(
f"'{type(self).__name__}' object has no attribute transformer_2"
"Note that Wan2.1 models do not have a transformer_2 component."
"Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
)
self.load_lora_into_transformer(
state_dict,
transformer=self.transformer_2,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
else:
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name)
if not hasattr(self, "transformer")
else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
@@ -5668,15 +5688,35 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
if not hasattr(self, "transformer_2"):
raise AttributeError(
f"'{type(self).__name__}' object has no attribute transformer_2"
"Note that Wan2.1 models do not have a transformer_2 component."
"Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
)
self.load_lora_into_transformer(
state_dict,
transformer=self.transformer_2,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
else:
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name)
if not hasattr(self, "transformer")
else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel