mirror of https://github.com/huggingface/diffusers.git
revert
@@ -4810,8 +4810,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
        transformer: torch.nn.Module,
        state_dict,
    ):
        # if transformer.config.image_dim is None:
        #     return state_dict
        print("BEFORE", list(state_dict.keys()))
        if transformer.config.image_dim is None:
            return state_dict

        target_device = transformer.device
@@ -4849,7 +4850,20 @@ class WanLoraLoaderMixin(LoraBaseMixin):
                            ref_lora_B_bias_tensor,
                            device=target_device,
                        )
        print(state_dict.keys())

        return state_dict

    @classmethod
    def _maybe_expand_t2v_lora_for_vace(
        cls,
        transformer: torch.nn.Module,
        state_dict,
    ):

        if not hasattr(transformer, 'vace_blocks'):
            return state_dict

        target_device = transformer.device

        return state_dict
@@ -4905,6 +4919,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        print("AFTER:", list(state_dict.keys()))
        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
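For context, the hunk around torch.zeros_like(ref_lora_B_bias_tensor, device=target_device) belongs to the helper that pads a text-to-video (T2V) LoRA so it can be loaded into the image-to-video (I2V) transformer, which has extra image cross-attention projections the LoRA was never trained on. Below is a minimal, self-contained sketch of that expansion idea; the key names, block count, and shapes are illustrative only, not the exact checkpoint keys or helper signature diffusers uses.

import torch

def expand_t2v_lora_for_i2v(state_dict, num_blocks):
    # For each block, create zero-filled LoRA tensors for the extra image
    # cross-attention projections, shaped like the existing to_k tensors.
    # Zero tensors keep the LoRA a no-op on the new projections.
    for i in range(num_blocks):
        ref_a = state_dict[f"blocks.{i}.attn2.to_k.lora_A.weight"]
        ref_b = state_dict[f"blocks.{i}.attn2.to_k.lora_B.weight"]
        for extra_proj in ("add_k_proj", "add_v_proj"):
            state_dict[f"blocks.{i}.attn2.{extra_proj}.lora_A.weight"] = torch.zeros_like(ref_a)
            state_dict[f"blocks.{i}.attn2.{extra_proj}.lora_B.weight"] = torch.zeros_like(ref_b)
    return state_dict

# Tiny fake T2V LoRA: one block, rank 4, hidden size 8 (illustrative shapes).
fake = {
    "blocks.0.attn2.to_k.lora_A.weight": torch.randn(4, 8),
    "blocks.0.attn2.to_k.lora_B.weight": torch.randn(8, 4),
}
print("BEFORE", list(fake.keys()))
expanded = expand_t2v_lora_for_i2v(fake, num_blocks=1)
print("AFTER:", list(expanded.keys()))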