mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Wan LoRAs] make T2V LoRAs compatible with Wan I2V (#11107)
* @hlky t2v->i2v
* Apply style fixes
* try with ones to not nullify layers
* fix method name
* revert to zeros
* add check to state_dict keys
* add comment
* copies fix
* Revert "copies fix"
This reverts commit 051f534d18.
* remove copied from
* Update src/diffusers/loaders/lora_pipeline.py
Co-authored-by: hlky <hlky@hlky.ac>
* Update src/diffusers/loaders/lora_pipeline.py
Co-authored-by: hlky <hlky@hlky.ac>
* update
* update
* Update src/diffusers/loaders/lora_pipeline.py
Co-authored-by: hlky <hlky@hlky.ac>
* Apply style fixes
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Linoy <linoy@hf.co>
Co-authored-by: hlky <hlky@hlky.ac>
This commit is contained in:
@@ -4249,7 +4249,33 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
@classmethod
|
||||
def _maybe_expand_t2v_lora_for_i2v(
|
||||
cls,
|
||||
transformer: torch.nn.Module,
|
||||
state_dict,
|
||||
):
|
||||
if transformer.config.image_dim is None:
|
||||
return state_dict
|
||||
|
||||
if any(k.startswith("transformer.blocks.") for k in state_dict):
|
||||
num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
|
||||
is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
|
||||
|
||||
if is_i2v_lora:
|
||||
return state_dict
|
||||
|
||||
for i in range(num_blocks):
|
||||
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
|
||||
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
|
||||
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
|
||||
)
|
||||
state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
|
||||
state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
|
||||
)
|
||||
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
@@ -4287,7 +4313,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
|
||||
state_dict = self._maybe_expand_t2v_lora_for_i2v(
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
Reference in New Issue
Block a user