# Method of `WanLoraLoaderMixin` (src/diffusers/loaders/lora_pipeline.py).
# `load_lora_weights` calls it on the result of `self.lora_state_dict(...)`,
# before the "every key contains 'lora'" format check, to convert a T2V LoRA
# to an I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional
# (missing) image-projection layers.
@classmethod
def _maybe_expand_t2v_lora_for_i2v(
    cls,
    transformer: torch.nn.Module,
    state_dict,
):
    """Add zero-filled image-projection LoRA weights to a T2V LoRA state dict.

    For every transformer block that carries cross-attention key LoRA weights
    (`transformer.blocks.<i>.attn2.to_k.lora_A/lora_B.weight`), entries for
    `attn2.add_k_proj` and `attn2.add_v_proj` are created with
    `torch.zeros_like` of the matching `to_k` tensor. Zero-filled A and B
    matrices presumably leave the extra projections unaffected by the adapter.

    Args:
        transformer: Model whose `config.image_dim` decides whether expansion
            is needed. `None` is treated as "nothing to expand" (presumably a
            T2V transformer -- confirm against the model config).
        state_dict: Mapping of LoRA parameter names to tensors. It is modified
            in place.

    Returns:
        The same `state_dict` object, possibly with the additional keys.

    Note:
        Both `add_k_proj` and `add_v_proj` take their shapes from `to_k`; this
        assumes the image projections share `to_k`'s shape -- TODO confirm.
    """
    if transformer.config.image_dim is None:
        return state_dict

    # A checkpoint that already ships both image projections is an I2V LoRA.
    has_add_k = any("add_k_proj" in key for key in state_dict)
    has_add_v = any("add_v_proj" in key for key in state_dict)
    if has_add_k and has_add_v:
        return state_dict

    prefix = "transformer.blocks."
    ref_a_suffix = ".attn2.to_k.lora_A.weight"
    ref_b_suffix = ".attn2.to_k.lora_B.weight"

    # Collect the block indices that actually have a reference tensor instead
    # of assuming every key contains "blocks." and that indices run 0..n-1.
    candidate_ids = {
        key[len(prefix) : -len(ref_a_suffix)]
        for key in state_dict
        if key.startswith(prefix) and key.endswith(ref_a_suffix)
    }
    block_ids = sorted((bid for bid in candidate_ids if bid.isdigit()), key=int)

    # `block_ids` is a list built up front, so inserting into `state_dict`
    # below never mutates a mapping that is being iterated.
    for block_id in block_ids:
        ref_a_key = f"{prefix}{block_id}{ref_a_suffix}"
        ref_b_key = f"{prefix}{block_id}{ref_b_suffix}"
        if ref_b_key not in state_dict:
            # Incomplete reference pair: leave this block untouched.
            continue

        for proj in ("add_k_proj", "add_v_proj"):
            state_dict[f"{prefix}{block_id}.attn2.{proj}.lora_A.weight"] = torch.zeros_like(
                state_dict[ref_a_key]
            )
            state_dict[f"{prefix}{block_id}.attn2.{proj}.lora_B.weight"] = torch.zeros_like(
                state_dict[ref_b_key]
            )

    return state_dict