commit b8a371ef69 (parent f81d4c6a5d)
Author: linoytsaban
Date:   2025-05-21 19:47:02 +03:00

@@ -4859,11 +4859,57 @@ class WanLoraLoaderMixin(LoraBaseMixin):
        transformer: torch.nn.Module,
        state_dict,
    ):
        if not hasattr(transformer, 'vace_blocks'):
            return state_dict

        target_device = transformer.device

        if hasattr(transformer, 'vace_blocks'):
            inferred_rank_for_vace = None
            lora_weights_dtype_for_vace = next(iter(transformer.parameters())).dtype  # Fallback dtype
            for k_lora_any, v_lora_tensor_any in state_dict.items():
                if k_lora_any.endswith(".lora_A.weight"):
                    inferred_rank_for_vace = v_lora_tensor_any.shape[0]
                    lora_weights_dtype_for_vace = v_lora_tensor_any.dtype
                    break  # Found one, good enough for rank and dtype

            if inferred_rank_for_vace is not None:
                # Determine if the LoRA format (as potentially modified by I2V expansion) includes bias.
                # This re-checks 'has_bias' based on the *current* state_dict.
                current_lora_has_bias = any(".lora_B.bias" in k for k in state_dict.keys())

                for i, vace_block_module_in_model in enumerate(transformer.vace_blocks):
                    # Specifically target proj_out as per the error message
                    if hasattr(vace_block_module_in_model, 'proj_out') and isinstance(
                        vace_block_module_in_model.proj_out, nn.Linear
                    ):
                        proj_out_linear_layer_in_model = vace_block_module_in_model.proj_out
                        vace_lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
                        vace_lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"

                        if vace_lora_A_key not in state_dict:
                            state_dict[vace_lora_A_key] = torch.zeros(
                                (inferred_rank_for_vace, proj_out_linear_layer_in_model.in_features),
                                device=target_device,
                                dtype=lora_weights_dtype_for_vace,
                            )
                        if vace_lora_B_key not in state_dict:
                            state_dict[vace_lora_B_key] = torch.zeros(
                                (proj_out_linear_layer_in_model.out_features, inferred_rank_for_vace),
                                device=target_device,
                                dtype=lora_weights_dtype_for_vace,
                            )

                        # Use 'current_lora_has_bias' to decide on padding bias for VACE blocks
                        if current_lora_has_bias and proj_out_linear_layer_in_model.bias is not None:
                            vace_lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
                            if vace_lora_B_bias_key not in state_dict:
                                state_dict[vace_lora_B_bias_key] = torch.zeros_like(
                                    proj_out_linear_layer_in_model.bias,  # Shape and dtype from the model's bias
                                    device=target_device,
                                )

                return state_dict

        return state_dict
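
The zero-padding above is safe because a LoRA pair that is all zeros is a mathematical no-op: the adapter's weight update is lora_B @ lora_A, which vanishes when either factor is zero. A minimal standalone sketch (illustrative shapes, not part of the diff) that checks this property:

import torch

# Illustrative sizes; in the patch they come from the inferred LoRA rank and the
# in/out features of each vace_blocks.{i}.proj_out nn.Linear layer.
rank, in_features, out_features = 16, 1536, 1536

lora_A = torch.zeros(rank, in_features)   # same layout as ".lora_A.weight" entries
lora_B = torch.zeros(out_features, rank)  # same layout as ".lora_B.weight" entries

# A LoRA adapter applies W + scale * (lora_B @ lora_A); zero factors leave W untouched,
# so the padded keys only satisfy the loader's completeness checks.
delta_W = lora_B @ lora_A
assert torch.count_nonzero(delta_W) == 0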
@@ -4915,6 +4961,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            state_dict=state_dict,
        )
        state_dict = self._maybe_expand_t2v_lora_for_vace(
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            state_dict=state_dict,
        )

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")
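
With the second hunk wired into the loading path, a Wan T2V-style LoRA should load into a VACE pipeline without missing-key errors. A minimal usage sketch, assuming the diffusers WanVACEPipeline and the Wan-AI/Wan2.1-VACE-1.3B-diffusers checkpoint are available; the LoRA repo id below is a placeholder:

import torch
from diffusers import WanVACEPipeline

# The transformer of a VACE checkpoint exposes `vace_blocks`, which triggers the expansion above.
pipe = WanVACEPipeline.from_pretrained("Wan-AI/Wan2.1-VACE-1.3B-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Placeholder repo id for a text-to-video LoRA with no vace_blocks.* entries;
# _maybe_expand_t2v_lora_for_vace zero-pads those keys while the state dict is loaded.
pipe.load_lora_weights("some-user/wan-t2v-style-lora", adapter_name="style")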