1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
linoytsaban
2025-05-21 19:27:43 +03:00
parent f91dae9be3
commit f81d4c6a5d

View File

@@ -4810,8 +4810,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
transformer: torch.nn.Module,
state_dict,
):
# if transformer.config.image_dim is None:
# return state_dict
print("BEFORE", list(state_dict.keys()))
if transformer.config.image_dim is None:
return state_dict
target_device = transformer.device
@@ -4849,7 +4850,20 @@ class WanLoraLoaderMixin(LoraBaseMixin):
ref_lora_B_bias_tensor,
device=target_device,
)
print(state_dict.keys)
return state_dict
@classmethod
def _maybe_expand_t2v_lora_for_vace(
cls,
transformer: torch.nn.Module,
state_dict,
):
if not hasattr(transformer, 'vace_blocks'):
return state_dict
target_device = transformer.device
return state_dict
@@ -4905,6 +4919,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
print("AFTER:", list(state_dict.keys()))
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,