From c665a0412ce35dff3f90be7fb49cc28a6d1a0041 Mon Sep 17 00:00:00 2001
From: linoytsaban
Date: Wed, 21 May 2025 19:11:44 +0300
Subject: [PATCH] revert

---
 src/diffusers/loaders/lora_pipeline.py | 58 +++----------------------
 1 file changed, 6 insertions(+), 52 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 96996d46cb..2c29702490 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4810,11 +4810,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         transformer: torch.nn.Module,
         state_dict,
     ):
-        print("wtf 0", hasattr(transformer, 'vace_blocks'))
-        # if transformer.config.image_dim is None:
-        #     return state_dict
+        if transformer.config.image_dim is None:
+            return state_dict
 
         target_device = transformer.device
+
         if any(k.startswith("transformer.blocks.") for k in state_dict):
             num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
             is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
@@ -4833,10 +4833,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                         continue
 
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[ref_key_lora_A], device=target_device  # Using original ref_key_lora_A
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
                     )
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[ref_key_lora_B], device=target_device  # Using original ref_key_lora_B
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
                     )
 
                     # If the original LoRA had biases (indicated by has_bias)
@@ -4849,52 +4849,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                             ref_lora_B_bias_tensor,
                             device=target_device,
                         )
-
-
-        if hasattr(transformer, 'vace_blocks'):
-            print(f"{i}, WTF 0")
-            inferred_rank_for_vace = None
-            lora_weights_dtype_for_vace = next(iter(transformer.parameters())).dtype  # Fallback dtype
-
-            for k_lora_any, v_lora_tensor_any in state_dict.items():
-                if k_lora_any.endswith(".lora_A.weight"):
-                    inferred_rank_for_vace = v_lora_tensor_any.shape[0]
-                    lora_weights_dtype_for_vace = v_lora_tensor_any.dtype
-                    break  # Found one, good enough for rank and dtype
-
-            if inferred_rank_for_vace is not None:
-                current_lora_has_bias = any(".lora_B.bias" in k for k in state_dict.keys())
-
-                for i, vace_block_module_in_model in enumerate(transformer.vace_blocks):
-                    if hasattr(vace_block_module_in_model, 'proj_out'):
-
-                        proj_out_linear_layer_in_model = vace_block_module_in_model.proj_out
-
-                        vace_lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
-                        vace_lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
-
-                        if vace_lora_A_key not in state_dict:
-                            print(f"{i}, WTF 1")
-                            state_dict[vace_lora_A_key] = torch.zeros(
-                                (inferred_rank_for_vace, proj_out_linear_layer_in_model.in_features),
-                                device=target_device, dtype=lora_weights_dtype_for_vace
-                            )
-
-                        if vace_lora_B_key not in state_dict:
-                            print(f"{i}, WTF 2")
-                            state_dict[vace_lora_B_key] = torch.zeros(
-                                (proj_out_linear_layer_in_model.out_features, inferred_rank_for_vace),
-                                device=target_device, dtype=lora_weights_dtype_for_vace
-                            )
-
-                        if current_lora_has_bias and proj_out_linear_layer_in_model.bias is not None:
-                            print(f"{i}, WTF 3")
-                            vace_lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
-                            if vace_lora_B_bias_key not in state_dict:
-                                state_dict[vace_lora_B_bias_key] = torch.zeros_like(
-                                    proj_out_linear_layer_in_model.bias,
-                                    device=target_device
-                                )
+        print(state_dict.keys)
 
         return state_dict
 
@@ -4942,7 +4897,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
-        print("_maybe_expand_t2v_lora_for_i2v?????????????????")
         state_dict = self._maybe_expand_t2v_lora_for_i2v(
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             state_dict=state_dict,
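
Note for reviewers: the sketch below (placed after the diff, so git am ignores it) illustrates the T2V-to-I2V expansion the retained code performs, i.e. zero-filling the missing add_k_proj/add_v_proj LoRA matrices using the existing to_k weights as shape references. The function name expand_t2v_lora_for_i2v and the num_blocks parameter are illustrative only, not the diffusers API; the real helper is WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v shown in the patch.

# Minimal sketch (not the diffusers implementation): pad a text-to-video Wan
# LoRA with zero LoRA weights for the I2V-only image cross-attention
# projections, so it loads into an I2V transformer without changing behaviour.
import torch


def expand_t2v_lora_for_i2v(state_dict: dict, num_blocks: int) -> dict:
    for i in range(num_blocks):
        # Use the existing to_k LoRA matrices as shape/dtype references.
        ref_a = state_dict.get(f"transformer.blocks.{i}.attn2.to_k.lora_A.weight")
        ref_b = state_dict.get(f"transformer.blocks.{i}.attn2.to_k.lora_B.weight")
        if ref_a is None or ref_b is None:
            continue  # nothing to mirror for this block
        for proj in ("add_k_proj", "add_v_proj"):
            # All-zero A/B matrices keep the extra projections inert.
            state_dict.setdefault(f"transformer.blocks.{i}.attn2.{proj}.lora_A.weight", torch.zeros_like(ref_a))
            state_dict.setdefault(f"transformer.blocks.{i}.attn2.{proj}.lora_B.weight", torch.zeros_like(ref_b))
    return state_dict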