diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 10b6a8f027..8770cfc0bd 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -16,6 +16,7 @@ import os
 from typing import Callable, Dict, List, Optional, Union
 
 import torch
+import re
 from huggingface_hub.utils import validate_hf_hub_args
 
 from ..utils import (
@@ -4805,50 +4806,152 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         return state_dict
 
     @classmethod
-    def _maybe_expand_t2v_lora_for_i2v(
-        cls,
-        transformer: torch.nn.Module,
-        state_dict,
-    ):
-        if transformer.config.image_dim is None:
-            return state_dict
+    def _maybe_expand_t2v_lora_for_i2v(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict: Dict[str, torch.Tensor],
+        lora_filename_for_rank_inference: Optional[str] = None,  # optional hint, e.g. a filename containing "rank32"
+    ) -> Dict[str, torch.Tensor]:
 
         target_device = transformer.device
+        # Default to the transformer dtype; refined below if the LoRA weights carry their own dtype.
+        lora_weights_dtype = next(iter(transformer.parameters())).dtype
 
-        if any(k.startswith("transformer.blocks.") for k in state_dict):
-            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
-            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
-            has_bias = any(".lora_B.bias" in k for k in state_dict)
-
-            if is_i2v_lora:
-                return state_dict
-
-            for i in range(num_blocks):
-                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                    # These keys should exist if the block `i` was part of the T2V LoRA.
-                    ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
-                    ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
-
-                    if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
-                        continue
-
-                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
-                    )
-                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
-                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
-                    )
-
-                    # If the original LoRA had biases (indicated by has_bias)
-                    # AND the specific reference bias key exists for this block.
-
-                    ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
-                    if has_bias and ref_key_lora_B_bias in state_dict:
-                        ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
-                        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
-                            ref_lora_B_bias_tensor,
-                            device=target_device,
-                        )
+        # --- Infer the LoRA rank (and refine the dtype) from the existing LoRA weights ---
+        inferred_rank = None
+        for k, v_tensor in state_dict.items():
+            if k.endswith(".lora_A.weight"):  # standard PEFT-style LoRA key
+                inferred_rank = v_tensor.shape[0]  # lora_A has shape (rank, in_features)
+                lora_weights_dtype = v_tensor.dtype
+                break
+
+        if inferred_rank is None and lora_filename_for_rank_inference is not None:
+            match = re.search(r"rank(\d+)", lora_filename_for_rank_inference, re.IGNORECASE)
+            if match:
+                inferred_rank = int(match.group(1))
+                logger.info(f"Inferred LoRA rank {inferred_rank} from the filename for padding.")
+
+        # Whether the original (T2V) LoRA format carries biases for lora_B.
+        lora_format_has_bias = any(".lora_B.bias" in k for k in state_dict)
+
+        # --- Part 1: original T2V -> I2V expansion for the standard `transformer.blocks` ---
+        if (
+            hasattr(transformer, "config")
+            and getattr(transformer.config, "image_dim", None) is not None
+            and hasattr(transformer, "blocks")
+        ):
+            standard_block_keys_present = any(k.startswith("transformer.blocks.") for k in state_dict)
+
+            if standard_block_keys_present and inferred_rank is not None:
+                block_indices = set()
+                for k_lora in state_dict:
+                    if "transformer.blocks." in k_lora:
+                        block_idx_str = k_lora.split("transformer.blocks.")[1].split(".")[0]
+                        if block_idx_str.isdigit():
+                            block_indices.add(int(block_idx_str))
+                num_blocks_in_lora = max(block_indices) + 1 if block_indices else 0
+
+                is_i2v_lora = any(
+                    k.startswith("transformer.blocks.") and "add_k_proj" in k for k in state_dict
+                ) and any(
+                    k.startswith("transformer.blocks.") and "add_v_proj" in k for k in state_dict
+                )
+
+                if not is_i2v_lora and num_blocks_in_lora > 0:
+                    logger.info(f"Expanding T2V LoRA for I2V compatibility (standard blocks). Rank: {inferred_rank}")
+                    for i in range(num_blocks_in_lora):
+                        # Only expand blocks that were actually part of the T2V LoRA.
+                        ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+                        if ref_key_lora_A not in state_dict:
+                            continue
+
+                        try:
+                            attn2 = transformer.blocks[i].attn2
+                            add_k_proj_layer = attn2.add_k_proj
+                            add_v_proj_layer = attn2.add_v_proj
+                        except (AttributeError, IndexError):
+                            logger.warning(f"Cannot access block {i} or its I2V projection layers; skipping it.")
+                            continue
+
+                        for proj_name, model_linear_layer in [
+                            ("add_k_proj", add_k_proj_layer),
+                            ("add_v_proj", add_v_proj_layer),
+                        ]:
+                            if not isinstance(model_linear_layer, torch.nn.Linear):
+                                continue
+
+                            lora_A_key = f"transformer.blocks.{i}.attn2.{proj_name}.lora_A.weight"
+                            lora_B_key = f"transformer.blocks.{i}.attn2.{proj_name}.lora_B.weight"
+
+                            if lora_A_key not in state_dict:
+                                state_dict[lora_A_key] = torch.zeros(
+                                    (inferred_rank, model_linear_layer.in_features),
+                                    device=target_device,
+                                    dtype=lora_weights_dtype,
+                                )
+                            if lora_B_key not in state_dict:
+                                state_dict[lora_B_key] = torch.zeros(
+                                    (model_linear_layer.out_features, inferred_rank),
+                                    device=target_device,
+                                    dtype=lora_weights_dtype,
+                                )
+
+                            if lora_format_has_bias and model_linear_layer.bias is not None:
+                                lora_B_bias_key = f"transformer.blocks.{i}.attn2.{proj_name}.lora_B.bias"
+                                if lora_B_bias_key not in state_dict:
+                                    state_dict[lora_B_bias_key] = torch.zeros_like(
+                                        model_linear_layer.bias, device=target_device, dtype=model_linear_layer.bias.dtype
+                                    )
+            elif inferred_rank is None:
+                logger.info("LoRA rank could not be inferred. Skipping the I2V expansion for the standard blocks.")
+
+        # --- Part 2: pad the LoRA for WanVACETransformer3DModel `vace_blocks.{i}.proj_out` layers ---
+        # Match on the class name so this module does not need a hard import of WanVACETransformer3DModel.
+        is_vace_transformer = transformer.__class__.__name__ == "WanVACETransformer3DModel"
+
+        if is_vace_transformer and hasattr(transformer, "vace_blocks"):
+            if inferred_rank is None:
+                logger.warning("LoRA rank could not be determined. Skipping VACE block padding for proj_out.")
+            else:
+                logger.info(f"Transformer is WanVACE. Padding LoRA for vace_blocks.X.proj_out. Rank: {inferred_rank}")
+                for i, vace_block in enumerate(transformer.vace_blocks):
+                    # Some VACE blocks may not expose a Linear proj_out; skip those.
+                    if not (hasattr(vace_block, "proj_out") and isinstance(vace_block.proj_out, torch.nn.Linear)):
+                        continue
+                    proj_out_layer = vace_block.proj_out
+
+                    # Keys PEFT expects in the state dict (before the adapter name is prepended).
+                    lora_A_key = f"vace_blocks.{i}.proj_out.lora_A.weight"
+                    lora_B_key = f"vace_blocks.{i}.proj_out.lora_B.weight"
+
+                    if lora_A_key not in state_dict:
+                        state_dict[lora_A_key] = torch.zeros(
+                            (inferred_rank, proj_out_layer.in_features),
+                            device=target_device,
+                            dtype=lora_weights_dtype,
+                        )
+                    if lora_B_key not in state_dict:
+                        state_dict[lora_B_key] = torch.zeros(
+                            (proj_out_layer.out_features, inferred_rank),
+                            device=target_device,
+                            dtype=lora_weights_dtype,
+                        )
+
+                    if lora_format_has_bias and proj_out_layer.bias is not None:
+                        lora_B_bias_key = f"vace_blocks.{i}.proj_out.lora_B.bias"
+                        if lora_B_bias_key not in state_dict:
+                            state_dict[lora_B_bias_key] = torch.zeros_like(
+                                proj_out_layer.bias, device=target_device, dtype=proj_out_layer.bias.dtype
                            )
 
         return state_dict
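
A minimal sketch (separate from the diff) of how the rank and dtype inference above behaves on a PEFT-style Wan T2V LoRA state dict, and of why the zero-initialized padding is a functional no-op; the key names and shapes are illustrative only:

    import torch

    state_dict = {
        "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(32, 1536, dtype=torch.bfloat16),
        "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(1536, 32, dtype=torch.bfloat16),
    }

    inferred_rank, lora_dtype = None, None
    for k, v in state_dict.items():
        if k.endswith(".lora_A.weight"):
            inferred_rank, lora_dtype = v.shape[0], v.dtype  # lora_A has shape (rank, in_features)
            break
    print(inferred_rank, lora_dtype)  # 32 torch.bfloat16

    # Padding an untrained layer with all-zero lora_A/lora_B adds B @ A @ x == 0, i.e. it is a no-op.
    lora_A = torch.zeros(inferred_rank, 1536)
    lora_B = torch.zeros(1536, inferred_rank)
    x = torch.randn(1536)
    assert torch.count_nonzero(lora_B @ (lora_A @ x)) == 0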
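
A hedged end-to-end usage sketch, assuming a Wan VACE pipeline class that mixes in WanLoraLoaderMixin (WanVACEPipeline is used here as an example); the checkpoint ID, LoRA repository, and weight name are placeholders, not recommendations:

    import torch
    from diffusers import WanVACEPipeline  # assumed to expose load_lora_weights via WanLoraLoaderMixin

    pipe = WanVACEPipeline.from_pretrained("Wan-AI/Wan2.1-VACE-1.3B-diffusers", torch_dtype=torch.bfloat16)

    # A LoRA trained on the T2V transformer: the padding above fills in zero-valued lora_A/lora_B
    # tensors for vace_blocks.{i}.proj_out (and, on I2V models, add_k_proj/add_v_proj), so the T2V
    # adapter can be attached to the VACE transformer without key mismatches. A "rank32"-style token
    # in the weight name is the kind of hint the optional filename argument in the diff can pick up.
    pipe.load_lora_weights("some-user/wan-t2v-style-lora", weight_name="wan_t2v_style_rank32.safetensors")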