diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 28ef8e63b5..81a927bd55 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4859,11 +4859,57 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         transformer: torch.nn.Module,
         state_dict,
     ):
-
-        if not hasattr(transformer, 'vace_blocks'):
-            return state_dict
-
         target_device = transformer.device
+
+        if hasattr(transformer, "vace_blocks"):
+            inferred_rank_for_vace = None
+            lora_weights_dtype_for_vace = next(iter(transformer.parameters())).dtype  # fallback dtype
+
+            for k_lora_any, v_lora_tensor_any in state_dict.items():
+                if k_lora_any.endswith(".lora_A.weight"):
+                    inferred_rank_for_vace = v_lora_tensor_any.shape[0]
+                    lora_weights_dtype_for_vace = v_lora_tensor_any.dtype
+                    break  # One entry is enough to infer rank and dtype.
+
+            if inferred_rank_for_vace is not None:
+                # Determine whether the LoRA format (as potentially modified by the I2V expansion)
+                # includes bias terms; re-check against the *current* state_dict.
+                current_lora_has_bias = any(".lora_B.bias" in k for k in state_dict.keys())
+
+                for i, vace_block_module_in_model in enumerate(transformer.vace_blocks):
+                    # Pad only proj_out, the layer reported missing when a T2V LoRA is loaded into a VACE model.
+                    if hasattr(vace_block_module_in_model, "proj_out") and isinstance(
+                        vace_block_module_in_model.proj_out, torch.nn.Linear
+                    ):
+                        proj_out_linear_layer_in_model = vace_block_module_in_model.proj_out
+
+                        vace_lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
+                        vace_lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
+
+                        if vace_lora_A_key not in state_dict:
+                            state_dict[vace_lora_A_key] = torch.zeros(
+                                (inferred_rank_for_vace, proj_out_linear_layer_in_model.in_features),
+                                device=target_device, dtype=lora_weights_dtype_for_vace
+                            )
+
+                        if vace_lora_B_key not in state_dict:
+                            state_dict[vace_lora_B_key] = torch.zeros(
+                                (proj_out_linear_layer_in_model.out_features, inferred_rank_for_vace),
+                                device=target_device, dtype=lora_weights_dtype_for_vace
+                            )
+
+                        # Pad the bias only when the (possibly I2V-expanded) LoRA format carries biases.
+                        if current_lora_has_bias and proj_out_linear_layer_in_model.bias is not None:
+                            vace_lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
+                            if vace_lora_B_bias_key not in state_dict:
+                                state_dict[vace_lora_B_bias_key] = torch.zeros_like(
+                                    proj_out_linear_layer_in_model.bias,  # shape and dtype follow the model's bias
+                                    device=target_device,
+                                )
+
+        return state_dict
@@ -4915,6 +4961,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             state_dict=state_dict,
         )
+        state_dict = self._maybe_expand_t2v_lora_for_vace(
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            state_dict=state_dict,
+        )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
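
A minimal usage sketch of the padding path above, assuming a Wan VACE pipeline checkpoint and a T2V-style LoRA whose state dict contains no vace_blocks.* keys; the repo ids and the adapter name are placeholders, not part of this diff:

    # Illustrative only: checkpoint/LoRA ids and the adapter name are assumed placeholders.
    import torch
    from diffusers import WanVACEPipeline

    pipe = WanVACEPipeline.from_pretrained(
        "<wan-vace-checkpoint>", torch_dtype=torch.bfloat16
    )

    # The T2V LoRA carries no entries for the extra VACE blocks. load_lora_weights()
    # now routes the state dict through _maybe_expand_t2v_lora_for_vace(), which pads
    # zero-filled lora_A/lora_B (and, when the format has them, lora_B.bias) tensors
    # for each vace_blocks.{i}.proj_out before the adapter is attached.
    pipe.load_lora_weights("<t2v-lora-repo-or-path>", adapter_name="t2v_style")

Because the padded lora_A and lora_B tensors are zero-initialized, their product contributes an identically zero delta, so the VACE blocks behave exactly as before while the loader sees a complete set of keys instead of failing on the missing proj_out entries.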