From f81d4c6a5d021859bb99f5469536593c24d035d3 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Wed, 21 May 2025 19:27:43 +0300 Subject: [PATCH] revert --- src/diffusers/loaders/lora_pipeline.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7c92bf0761..28ef8e63b5 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4810,8 +4810,9 @@ class WanLoraLoaderMixin(LoraBaseMixin): transformer: torch.nn.Module, state_dict, ): - # if transformer.config.image_dim is None: - # return state_dict + print("BEFORE", list(state_dict.keys())) + if transformer.config.image_dim is None: + return state_dict target_device = transformer.device @@ -4849,7 +4850,20 @@ class WanLoraLoaderMixin(LoraBaseMixin): ref_lora_B_bias_tensor, device=target_device, ) - print(state_dict.keys) + + return state_dict + + @classmethod + def _maybe_expand_t2v_lora_for_vace( + cls, + transformer: torch.nn.Module, + state_dict, + ): + + if not hasattr(transformer, 'vace_blocks'): + return state_dict + + target_device = transformer.device return state_dict @@ -4905,6 +4919,7 @@ class WanLoraLoaderMixin(LoraBaseMixin): if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") + print("AFTER:", list(state_dict.keys())) self.load_lora_into_transformer( state_dict, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,