diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index df649a381c..1bf738d14c 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4833,10 +4833,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                         continue
 
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+                        state_dict[ref_key_lora_A], device=target_device  # Using original ref_key_lora_A
                     )
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+                        state_dict[ref_key_lora_B], device=target_device  # Using original ref_key_lora_B
                     )
 
                     # If the original LoRA had biases (indicated by has_bias)
@@ -4850,6 +4850,47 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                             device=target_device,
                         )
 
+        if hasattr(transformer, 'vace_blocks'):
+            inferred_rank_for_vace = None
+            lora_weights_dtype_for_vace = next(iter(transformer.parameters())).dtype  # Fallback dtype
+
+            for k_lora_any, v_lora_tensor_any in state_dict.items():
+                if k_lora_any.endswith(".lora_A.weight"):
+                    inferred_rank_for_vace = v_lora_tensor_any.shape[0]
+                    lora_weights_dtype_for_vace = v_lora_tensor_any.dtype
+                    break  # Found one, good enough for rank and dtype
+
+            if inferred_rank_for_vace is not None:
+                current_lora_has_bias = any(".lora_B.bias" in k for k in state_dict.keys())
+
+                for i, vace_block_module_in_model in enumerate(transformer.vace_blocks):
+                    if hasattr(vace_block_module_in_model, 'proj_out'):
+
+                        proj_out_linear_layer_in_model = vace_block_module_in_model.proj_out
+
+                        vace_lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
+                        vace_lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
+
+                        if vace_lora_A_key not in state_dict:
+                            state_dict[vace_lora_A_key] = torch.zeros(
+                                (inferred_rank_for_vace, proj_out_linear_layer_in_model.in_features),
+                                device=target_device, dtype=lora_weights_dtype_for_vace
+                            )
+
+                        if vace_lora_B_key not in state_dict:
+                            state_dict[vace_lora_B_key] = torch.zeros(
+                                (proj_out_linear_layer_in_model.out_features, inferred_rank_for_vace),
+                                device=target_device, dtype=lora_weights_dtype_for_vace
+                            )
+
+                        if current_lora_has_bias and proj_out_linear_layer_in_model.bias is not None:
+                            vace_lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
+                            if vace_lora_B_bias_key not in state_dict:
+                                state_dict[vace_lora_B_bias_key] = torch.zeros_like(
+                                    proj_out_linear_layer_in_model.bias,
+                                    device=target_device
+                                )
+
         return state_dict
 
     def load_lora_weights(
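A minimal, self-contained sketch of what the second hunk does (not diffusers code; ToyVaceBlock, ToyTransformer, and expand_lora_for_vace below are hypothetical stand-ins): the rank and dtype are inferred from any existing lora_A weight, and every vace_blocks.{i}.proj_out layer the LoRA does not cover receives zero-filled lora_A/lora_B factors, so B @ A is zero and the module's output is unchanged while the adapter covers all targeted modules.

import torch
import torch.nn as nn


class ToyVaceBlock(nn.Module):
    # Hypothetical stand-in for a Wan VACE block that exposes a `proj_out` linear layer.
    def __init__(self, dim=16):
        super().__init__()
        self.proj_out = nn.Linear(dim, dim)


class ToyTransformer(nn.Module):
    # Hypothetical stand-in for a transformer with a `vace_blocks` ModuleList.
    def __init__(self, num_vace_blocks=2, dim=16):
        super().__init__()
        self.vace_blocks = nn.ModuleList(ToyVaceBlock(dim) for _ in range(num_vace_blocks))


def expand_lora_for_vace(transformer, state_dict, target_device="cpu"):
    # Infer rank and dtype from any existing lora_A weight, mirroring the patch.
    rank, dtype = None, next(iter(transformer.parameters())).dtype
    for k, v in state_dict.items():
        if k.endswith(".lora_A.weight"):
            rank, dtype = v.shape[0], v.dtype
            break
    if rank is None:
        return state_dict

    for i, block in enumerate(transformer.vace_blocks):
        a_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
        b_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
        # Zero-filled factors: B @ A == 0, so the block's behavior is unchanged.
        state_dict.setdefault(
            a_key, torch.zeros(rank, block.proj_out.in_features, device=target_device, dtype=dtype)
        )
        state_dict.setdefault(
            b_key, torch.zeros(block.proj_out.out_features, rank, device=target_device, dtype=dtype)
        )
    return state_dict


# Usage: a LoRA that only targets attention layers gets zero-padded VACE entries.
model = ToyTransformer()
lora = {"blocks.0.attn2.to_k.lora_A.weight": torch.randn(4, 16)}
lora = expand_lora_for_vace(model, lora)
assert "vace_blocks.1.proj_out.lora_B.weight" in lora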