
[LoRA] feat: support more Qwen LoRAs from the community. (#12170)

* feat: support more Qwen LoRAs from the community.

* revert unrelated changes.

* Revert "revert unrelated changes."

This reverts commit 82dea555dc.
Author:  Sayak Paul (committed by GitHub)
Date:    2025-08-18 20:56:28 +05:30
Parent:  5b53f67f06
Commit:  555b6cc34f

2 changed files with 70 additions and 1 deletion


@@ -2080,6 +2080,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
 def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+    if has_lora_unet:
+        state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+
+        def convert_key(key: str) -> str:
+            prefix = "transformer_blocks"
+            if "." in key:
+                base, suffix = key.rsplit(".", 1)
+            else:
+                base, suffix = key, ""
+
+            start = f"{prefix}_"
+            rest = base[len(start) :]
+
+            if "." in rest:
+                head, tail = rest.split(".", 1)
+                tail = "." + tail
+            else:
+                head, tail = rest, ""
+
+            # Protected n-grams that must keep their internal underscores
+            protected = {
+                # pairs
+                ("to", "q"),
+                ("to", "k"),
+                ("to", "v"),
+                ("to", "out"),
+                ("add", "q"),
+                ("add", "k"),
+                ("add", "v"),
+                ("txt", "mlp"),
+                ("img", "mlp"),
+                ("txt", "mod"),
+                ("img", "mod"),
+                # triplets
+                ("add", "q", "proj"),
+                ("add", "k", "proj"),
+                ("add", "v", "proj"),
+                ("to", "add", "out"),
+            }
+
+            prot_by_len = {}
+            for ng in protected:
+                prot_by_len.setdefault(len(ng), set()).add(ng)
+
+            parts = head.split("_")
+            merged = []
+            i = 0
+            lengths_desc = sorted(prot_by_len.keys(), reverse=True)
+            while i < len(parts):
+                matched = False
+                for L in lengths_desc:
+                    if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
+                        merged.append("_".join(parts[i : i + L]))
+                        i += L
+                        matched = True
+                        break
+                if not matched:
+                    merged.append(parts[i])
+                    i += 1
+
+            head_converted = ".".join(merged)
+
+            converted_base = f"{prefix}.{head_converted}{tail}"
+            return converted_base + (("." + suffix) if suffix else "")
+
+        state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
     converted_state_dict = {}
     all_keys = list(state_dict.keys())
     down_key = ".lora_down.weight"


@@ -6643,7 +6643,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
         state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
-        if has_alphas_in_sd:
+        has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+        if has_alphas_in_sd or has_lora_unet:
             state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
 
         out = (state_dict, metadata) if return_lora_metadata else state_dict
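A short sketch of what the widened check catches (the state dict below is hypothetical): community checkpoints that use the "lora_unet_" key prefix but ship no ".alpha" entries previously slipped past the alpha-only test.

# Hypothetical minimal state dict: "lora_unet_"-prefixed keys, no ".alpha" entries.
state_dict = {
    "lora_unet_transformer_blocks_0_attn_to_q.lora_down.weight": ...,
    "lora_unet_transformer_blocks_0_attn_to_q.lora_up.weight": ...,
}

has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)     # False
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)  # True

# Before this commit only has_alphas_in_sd gated the conversion, so such a
# checkpoint was passed through with its original keys; the "or" now routes
# it through _convert_non_diffusers_qwen_lora_to_diffusers.
assert not has_alphas_in_sd and has_lora_unet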