Mirror of https://github.com/huggingface/diffusers.git

Commit: revert
@@ -4810,11 +4810,11 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         transformer: torch.nn.Module,
         state_dict,
     ):
-        print("wtf 0", hasattr(transformer, 'vace_blocks'))
-        # if transformer.config.image_dim is None:
-        #     return state_dict
+        if transformer.config.image_dim is None:
+            return state_dict
+
         target_device = transformer.device
 
         if any(k.startswith("transformer.blocks.") for k in state_dict):
             num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
             is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
@@ -4833,10 +4833,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                         continue
 
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[ref_key_lora_A], device=target_device # Using original ref_key_lora_A
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
                     )
                     state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[ref_key_lora_B], device=target_device # Using original ref_key_lora_B
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
                     )
 
                     # If the original LoRA had biases (indicated by has_bias)
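Note (not part of the commit): the hunk above zero-fills the I2V-only attention projections (add_k_proj / add_v_proj) with tensors shaped like the existing to_k LoRA weights, so a T2V LoRA can be loaded into an I2V transformer without changing what it computes. A minimal standalone sketch of why all-zero LoRA factors are a no-op, using made-up shapes rather than the real Wan dimensions:

import torch

# Pretend these came from a T2V LoRA checkpoint (rank 16, hidden size 128 are made-up numbers).
rank, hidden = 16, 128
lora_A_to_k = torch.randn(rank, hidden)   # lora_A.weight has shape (rank, in_features)
lora_B_to_k = torch.randn(hidden, rank)   # lora_B.weight has shape (out_features, rank)

state_dict = {
    "transformer.blocks.0.attn2.to_k.lora_A.weight": lora_A_to_k,
    "transformer.blocks.0.attn2.to_k.lora_B.weight": lora_B_to_k,
}

# Zero-fill the I2V-only projections, mirroring the zeros_like() calls in the hunk above.
for c in ("add_k_proj", "add_v_proj"):
    state_dict[f"transformer.blocks.0.attn2.{c}.lora_A.weight"] = torch.zeros_like(lora_A_to_k)
    state_dict[f"transformer.blocks.0.attn2.{c}.lora_B.weight"] = torch.zeros_like(lora_B_to_k)

# The weight delta a LoRA pair adds is lora_B @ lora_A, so zero factors contribute nothing.
delta = (
    state_dict["transformer.blocks.0.attn2.add_k_proj.lora_B.weight"]
    @ state_dict["transformer.blocks.0.attn2.add_k_proj.lora_A.weight"]
)
assert torch.count_nonzero(delta) == 0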
@@ -4849,52 +4849,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                             ref_lora_B_bias_tensor,
                             device=target_device,
                         )
 
 
-        if hasattr(transformer, 'vace_blocks'):
-            print(f"{i}, WTF 0")
-            inferred_rank_for_vace = None
-            lora_weights_dtype_for_vace = next(iter(transformer.parameters())).dtype  # Fallback dtype
-
-            for k_lora_any, v_lora_tensor_any in state_dict.items():
-                if k_lora_any.endswith(".lora_A.weight"):
-                    inferred_rank_for_vace = v_lora_tensor_any.shape[0]
-                    lora_weights_dtype_for_vace = v_lora_tensor_any.dtype
-                    break  # Found one, good enough for rank and dtype
-
-            if inferred_rank_for_vace is not None:
-                current_lora_has_bias = any(".lora_B.bias" in k for k in state_dict.keys())
-
-                for i, vace_block_module_in_model in enumerate(transformer.vace_blocks):
-                    if hasattr(vace_block_module_in_model, 'proj_out'):
-
-                        proj_out_linear_layer_in_model = vace_block_module_in_model.proj_out
-
-                        vace_lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
-                        vace_lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
-
-                        if vace_lora_A_key not in state_dict:
-                            print(f"{i}, WTF 1")
-                            state_dict[vace_lora_A_key] = torch.zeros(
-                                (inferred_rank_for_vace, proj_out_linear_layer_in_model.in_features),
-                                device=target_device, dtype=lora_weights_dtype_for_vace
-                            )
-
-                        if vace_lora_B_key not in state_dict:
-                            print(f"{i}, WTF 2")
-                            state_dict[vace_lora_B_key] = torch.zeros(
-                                (proj_out_linear_layer_in_model.out_features, inferred_rank_for_vace),
-                                device=target_device, dtype=lora_weights_dtype_for_vace
-                            )
-
-                        if current_lora_has_bias and proj_out_linear_layer_in_model.bias is not None:
-                            print(f"{i}, WTF 3")
-                            vace_lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
-                            if vace_lora_B_bias_key not in state_dict:
-                                state_dict[vace_lora_B_bias_key] = torch.zeros_like(
-                                    proj_out_linear_layer_in_model.bias,
-                                    device=target_device
-                                )
-        print(state_dict.keys)
 
         return state_dict
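Note (not part of the commit): the block removed above zero-filled vace_blocks.*.proj_out LoRA entries so that a LoRA without VACE weights could still be applied to a transformer that has vace_blocks, inferring the rank and dtype from any existing lora_A tensor. A condensed, standalone sketch of that inference, with a hypothetical stand-in layer instead of a real VACE block:

import torch
import torch.nn as nn

# Hypothetical stand-in for a VACE block's proj_out layer (real Wan dimensions differ).
proj_out = nn.Linear(in_features=128, out_features=128)

# Infer rank and dtype from any lora_A tensor already present, as the removed block did.
state_dict = {"transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(16, 128)}
ref = next(v for k, v in state_dict.items() if k.endswith(".lora_A.weight"))
rank, dtype = ref.shape[0], ref.dtype

# Zero-filled entries for the missing vace proj_out adapter:
# lora_A is (rank, in_features), lora_B is (out_features, rank).
state_dict["vace_blocks.0.proj_out.lora_A.weight"] = torch.zeros(
    (rank, proj_out.in_features), dtype=dtype
)
state_dict["vace_blocks.0.proj_out.lora_B.weight"] = torch.zeros(
    (proj_out.out_features, rank), dtype=dtype
)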
@@ -4942,7 +4897,6 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
-        print("_maybe_expand_t2v_lora_for_i2v?????????????????")
         state_dict = self._maybe_expand_t2v_lora_for_i2v(
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             state_dict=state_dict,
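Note (not part of the commit): the hunk above sits in the LoRA loading path, so from the user's side the T2V-to-I2V expansion happens automatically when a LoRA is loaded into a Wan I2V pipeline. A rough usage sketch; the checkpoint and LoRA identifiers below are placeholders to substitute with real ones:

import torch
from diffusers import WanImageToVideoPipeline

# Placeholder identifiers -- substitute a real Wan I2V checkpoint and a real T2V LoRA repo.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
)

# load_lora_weights() fetches the state dict via lora_state_dict() and then calls
# _maybe_expand_t2v_lora_for_i2v(), which zero-fills the missing _img layers as shown above.
pipe.load_lora_weights("some-user/some-wan-t2v-lora", adapter_name="t2v_style")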