Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Wan 2.2 LoRA] add support for 2nd transformer lora loading + wan 2.2 lightx2v lora (#12074)
* add alpha
* load into 2nd transformer
* Update src/diffusers/loaders/lora_conversion_utils.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update src/diffusers/loaders/lora_conversion_utils.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* pr comments
* pr comments
* pr comments
* fix
* fix
* Apply style fixes
* fix copies
* fix
* fix copies
* Update src/diffusers/loaders/lora_pipeline.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* revert change
* revert change
* fix copies
* up
* fix

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: linoy <linoy@hf.co>
@@ -333,6 +333,8 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
+- Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them with Wan 2.2 is slightly more involved. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.
+- Wan 2.2 has two denoisers. By default, LoRAs are only loaded into the first denoiser. Set `load_into_transformer_2=True` to load LoRAs into the second denoiser as well. Refer to [this example](https://github.com/huggingface/diffusers/pull/12074#issue-3292620048) and [this one](https://github.com/huggingface/diffusers/pull/12074#issuecomment-3155896144) to learn more.
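As a companion to the two notes above, here is a rough end-to-end sketch of loading a distillation LoRA into both Wan 2.2 denoisers. The checkpoint id, LoRA repository, file name, and sampler settings below are illustrative assumptions, not the exact snippet linked in the notes; adapt them to the files you actually use.

```python
import torch
from diffusers import WanPipeline

# Illustrative model id; substitute the Wan 2.2 checkpoint you use.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16)

# Load the LoRA into the first denoiser (default behavior) ...
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",                          # placeholder repo
    weight_name="Lightx2v/<lora-file>.safetensors",  # placeholder file name
    adapter_name="lightx2v",
)
# ... and, with this PR, into the second denoiser as well.
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Lightx2v/<lora-file>.safetensors",
    adapter_name="lightx2v_2",
    load_into_transformer_2=True,
)

pipe.to("cuda")
# Distillation LoRAs typically run with very few steps and no CFG; values are illustrative.
frames = pipe("a corgi surfing a small wave at sunset", num_inference_steps=4, guidance_scale=1.0).frames[0]
```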

## WanPipeline

[[autodoc]] WanPipeline
@@ -754,7 +754,11 @@ class LoraBaseMixin:
        # Decompose weights into weights for denoiser and text encoders.
        _component_adapter_weights = {}
        for component in self._lora_loadable_modules:
-            model = getattr(self, component)
+            model = getattr(self, component, None)
+            # To guard for cases like Wan. In Wan2.1 and WanVace, we have a single denoiser.
+            # Whereas in Wan 2.2, we have two denoisers.
+            if model is None:
+                continue

            for adapter_name, weights in zip(adapter_names, adapter_weights):
                if isinstance(weights, dict):
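For context, the guard above matters because `set_adapters` walks every entry of `_lora_loadable_modules`, and for Wan that list now includes `transformer_2`, which Wan 2.1 pipelines do not have. A hedged sketch of the calls it protects; adapter and component names are illustrative:

```python
# `pipe` is assumed to be a Wan pipeline with a LoRA already loaded under the name "lightx2v".
# Components from _lora_loadable_modules that resolve to None are simply skipped by the guard above.
pipe.set_adapters(["lightx2v"], adapter_weights=[1.0])

# Per-component scales should also work; this is what the `isinstance(weights, dict)` branch handles.
pipe.set_adapters(["lightx2v"], adapter_weights=[{"transformer": 1.0, "transformer_2": 0.8}])
```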
@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
            k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
        )

+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
    for key in list(original_state_dict.keys()):
        if key.endswith((".diff", ".diff_b")) and "norm" in key:
            # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
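The helper above distributes the overall LoRA scale `alpha / rank` across the down and up projections by powers of two instead of multiplying it into a single matrix, which keeps the magnitudes of the two factors closer together; their product always equals `alpha / rank`. A small standalone sketch of the same arithmetic with made-up values:

```python
def split_scale(alpha: float, rank: int):
    # Same loop as get_alpha_scales, minus the state-dict handling.
    scale = alpha / rank
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

for alpha, rank in [(8.0, 32), (16.0, 16), (4.0, 64)]:  # made-up alpha/rank pairs
    scale_down, scale_up = split_scale(alpha, rank)
    # (scale_down * A) and (scale_up * B) jointly reproduce the alpha / rank scaling.
    assert abs(scale_down * scale_up - alpha / rank) < 1e-9
    print(f"alpha={alpha}, rank={rank} -> scale_down={scale_down}, scale_up={scale_up}")
```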
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
    for i in range(min_block, max_block + 1):
        # Self-attention
        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"

-            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)

            original_key = f"blocks.{i}.self_attn.{o}.diff_b"
            converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
        # Cross-attention
        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"

-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up

            original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
            converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
        if is_i2v_lora:
            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+                has_alpha = alpha_key in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"

-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if original_key_A in original_state_dict:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    converted_state_dict[converted_key_A] = down_weight
+                if original_key_B in original_state_dict:
+                    up_weight = original_state_dict.pop(original_key_B)
+                    converted_state_dict[converted_key_B] = up_weight
+                if has_alpha:
+                    scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                    converted_state_dict[converted_key_A] *= scale_down
+                    converted_state_dict[converted_key_B] *= scale_up

                original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
        # FFN
        for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"

-            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up

            original_key = f"blocks.{i}.{o}.diff_b"
            converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
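All four blocks above (self-attention, cross-attention, the image-conditioning projections, and the FFN) apply the same rename-and-rescale step: a `lora_down`/`lora_up` pair is stored under diffusers-style `lora_A`/`lora_B` keys, and any per-layer `alpha` is folded into the weights. A toy sketch of one such mapping, with arbitrary shapes and values chosen purely for illustration:

```python
import torch

rank, dim = 4, 8  # arbitrary, for illustration only
original = {
    "blocks.0.self_attn.q.lora_down.weight": torch.randn(rank, dim),
    "blocks.0.self_attn.q.lora_up.weight": torch.randn(dim, rank),
    "blocks.0.self_attn.q.alpha": torch.tensor(2.0),
}

# alpha / rank = 0.5 here; get_alpha_scales would split it into (0.5, 1.0).
scale_down, scale_up = 0.5, 1.0
converted = {
    "blocks.0.attn1.to_q.lora_A.weight": original["blocks.0.self_attn.q.lora_down.weight"] * scale_down,
    "blocks.0.attn1.to_q.lora_B.weight": original["blocks.0.self_attn.q.lora_up.weight"] * scale_up,
}

# The rescaled pair yields the same weight delta as scaling up @ down by alpha / rank.
delta = converted["blocks.0.attn1.to_q.lora_B.weight"] @ converted["blocks.0.attn1.to_q.lora_A.weight"]
expected = (2.0 / rank) * (
    original["blocks.0.self_attn.q.lora_up.weight"] @ original["blocks.0.self_attn.q.lora_down.weight"]
)
assert torch.allclose(delta, expected)
```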
@@ -5065,7 +5065,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
    """

-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "transformer_2"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
@@ -5270,15 +5270,35 @@ class WanLoraLoaderMixin(LoraBaseMixin):
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute 'transformer_2'. "
+                    "Note that Wan 2.1 models do not have a transformer_2 component. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
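The practical effect of the restructured call above: `load_lora_weights` now accepts a `load_into_transformer_2` kwarg that routes the LoRA into the second denoiser, and fails fast on pipelines that don't have one. A hedged sketch; the repo path and pipeline variables are illustrative:

```python
# On a Wan 2.2 pipeline exposing both denoisers, call load_lora_weights once per transformer:
pipe.load_lora_weights("some/lora-repo", adapter_name="my_lora")                                  # -> transformer
pipe.load_lora_weights("some/lora-repo", adapter_name="my_lora_2", load_into_transformer_2=True)  # -> transformer_2

# On a Wan 2.1 pipeline, which has no transformer_2, the same kwarg raises AttributeError.
try:
    wan21_pipe.load_lora_weights("some/lora-repo", load_into_transformer_2=True)
except AttributeError as err:
    print(err)
```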
@@ -5668,15 +5688,35 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute 'transformer_2'. "
+                    "Note that Wan 2.1 models do not have a transformer_2 component. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel