Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
support latest few-step wan LoRA. (#12541)
* support latest few-step wan LoRA.
* up
* up
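For context on what the commit enables: few-step (distill) Wan LoRAs, such as the lightx2v checkpoints referenced in the hunk below, should now load through the standard LoRA entry point. A rough usage sketch follows; WanPipeline and load_lora_weights are existing diffusers APIs, but the base-model id and LoRA file name are placeholders, not values taken from this commit.

import torch
from diffusers import WanPipeline

# Placeholder ids: pick the Wan base model and LoRA file that actually match your checkpoint.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Few-step / distill LoRA from the repo mentioned in the conversion comment below.
pipe.load_lora_weights("lightx2v/Wan2.2-Distill-Loras", weight_name="<lora-file>.safetensors")

# Distill LoRAs are meant to run in a handful of steps with guidance effectively disabled.
video = pipe("a cat surfing a wave", num_inference_steps=4, guidance_scale=1.0).frames[0]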
@@ -1977,14 +1977,34 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
             "time_projection.1.diff_b"
         )
 
-    if any("head.head" in k for k in state_dict):
-        converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
-            f"head.head.{lora_down_key}.weight"
-        )
-        converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+    if any("head.head" in k for k in original_state_dict):
+        if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+                f"head.head.{lora_down_key}.weight"
+            )
+        if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
+                f"head.head.{lora_up_key}.weight"
+            )
         if "head.head.diff_b" in original_state_dict:
             converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
+
+        # Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
+        # This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
+        # Since for this particular LoRA, we don't have the corresponding up matrix, I will use
+        # an identity.
+        if any("head.head" in k and k.endswith(".diff") for k in state_dict):
+            if f"head.head.{lora_down_key}.weight" in state_dict:
+                logger.info(
+                    f"The state dict seems to have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
+                )
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
+            down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
+            up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
+            converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
+                *up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
+            ).T
 
     for text_time in ["text_embedding", "time_embedding"]:
         if any(text_time in k for k in original_state_dict):
             for b_n in [0, 2]:
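The notable piece of the new branch is the fallback for checkpoints (per the lightx2v note above) that only ship a full head.head.diff delta plus a head.head.diff_b bias, with no separate up/down factorization. Treating the delta itself as lora_A and an identity matrix as lora_B keeps the usual lora_B @ lora_A composition while reproducing the delta exactly. A minimal sketch of that reasoning with made-up shapes (not the real Wan proj_out dimensions):

import torch

# Hypothetical shapes standing in for the Wan head; the real dimensions come from the checkpoint.
out_features, in_features = 16, 32

# What a diff-only checkpoint provides: a full-rank weight delta and a bias delta.
diff = torch.randn(out_features, in_features)
diff_b = torch.randn(out_features)

# Mirror of the hunk: the delta becomes lora_A, and lora_B is an identity whose shape
# is derived from diff.shape[0] and the bias length (square here), then transposed.
lora_A = diff
up_matrix_shape = (diff.shape[0], diff_b.shape[0])
lora_B = torch.eye(*up_matrix_shape, dtype=diff.dtype).T

# The effective LoRA update lora_B @ lora_A is exactly the original delta.
assert torch.allclose(lora_B @ lora_A, diff)

Since diff.shape[0] matches the bias length for this head, the identity is square and the trailing .T only fixes the (out_features, rank) orientation that the lora_B slot expects.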