diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 7c92bf0761..28ef8e63b5 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -4810,8 +4810,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         transformer: torch.nn.Module,
         state_dict,
     ):
-        # if transformer.config.image_dim is None:
-        #     return state_dict
+        print("BEFORE", list(state_dict.keys()))
+        if transformer.config.image_dim is None:
+            return state_dict
 
         target_device = transformer.device
 
@@ -4849,7 +4850,20 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                             ref_lora_B_bias_tensor,
                             device=target_device,
                         )
-        print(state_dict.keys)
+
+        return state_dict
+
+    @classmethod
+    def _maybe_expand_t2v_lora_for_vace(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
+    ):
+
+        if not hasattr(transformer, 'vace_blocks'):
+            return state_dict
+
+        target_device = transformer.device
 
         return state_dict
 
@@ -4905,6 +4919,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
+        print("AFTER:", list(state_dict.keys()))
         self.load_lora_into_transformer(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
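
Note: the new _maybe_expand_t2v_lora_for_vace classmethod added in the second hunk is still a stub. As a rough illustration only, the standalone sketch below shows one way it might be completed, mirroring the zero-padding approach of the existing _maybe_expand_t2v_lora_for_i2v path: any Linear layer under vace_blocks that the T2V checkpoint does not cover receives zero-filled lora_A/lora_B entries so the expanded state dict stays shape-compatible with the VACE transformer. The key naming scheme, the rank inference, and the free-function form are assumptions for illustration, not the library's confirmed implementation.

# Illustrative sketch only -- not the merged implementation. Assumes peft-style
# key names of the form "transformer.vace_blocks.{i}.{module}.lora_A.weight".
import torch


def _expand_t2v_lora_for_vace_sketch(transformer, state_dict):
    # Nothing to expand for models without VACE blocks (plain T2V transformer).
    if not hasattr(transformer, "vace_blocks"):
        return state_dict

    target_device = transformer.device

    # Infer the LoRA rank from any existing lora_A weight in the checkpoint.
    inferred_rank = next(
        tensor.shape[0]
        for key, tensor in state_dict.items()
        if "lora_A" in key and key.endswith(".weight")
    )

    for i, block in enumerate(transformer.vace_blocks):
        for name, module in block.named_modules():
            if not isinstance(module, torch.nn.Linear):
                continue
            lora_A_key = f"transformer.vace_blocks.{i}.{name}.lora_A.weight"
            lora_B_key = f"transformer.vace_blocks.{i}.{name}.lora_B.weight"
            if lora_A_key in state_dict:
                continue
            # Zero-filled adapters leave the VACE blocks mathematically untouched
            # while keeping the expanded checkpoint loadable against the model.
            state_dict[lora_A_key] = torch.zeros(
                (inferred_rank, module.in_features), device=target_device
            )
            state_dict[lora_B_key] = torch.zeros(
                (module.out_features, inferred_rank), device=target_device
            )

    return state_dict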