mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
feat: support more Qwen LoRAs from the community.
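The change set below touches four places: the top-level pipeline registry, the LoRA key-conversion utilities, the QwenImage LoRA loader mixin, and the qwenimage module's import table. In practice it enables loading community Qwen-Image LoRAs whose checkpoints use the non-diffusers "lora_unet_*" key layout. A minimal usage sketch, assuming a hypothetical LoRA repository id ("some-user/qwen-image-lora" is a stand-in, not a real checkpoint):

import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
# Hypothetical repo id; any community LoRA with "lora_unet_"-prefixed keys
# should now be converted automatically on load.
pipe.load_lora_weights("some-user/qwen-image-lora")
image = pipe("a capybara reading a newspaper", num_inference_steps=30).images[0]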
@@ -489,10 +489,10 @@ else:
             "PixArtAlphaPipeline",
             "PixArtSigmaPAGPipeline",
             "PixArtSigmaPipeline",
+            "QwenImageEditPipeline",
             "QwenImageImg2ImgPipeline",
             "QwenImageInpaintPipeline",
             "QwenImagePipeline",
-            "QwenImageEditPipeline",
             "ReduxImageEncoder",
             "SanaControlNetPipeline",
             "SanaPAGPipeline",
@@ -2080,6 +2080,74 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
 
 
+def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+    has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+    if has_lora_unet:
+        state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+
+    def convert_key(key: str) -> str:
+        prefix = "transformer_blocks"
+        if "." in key:
+            base, suffix = key.rsplit(".", 1)
+        else:
+            base, suffix = key, ""
+
+        start = f"{prefix}_"
+        rest = base[len(start) :]
+
+        if "." in rest:
+            head, tail = rest.split(".", 1)
+            tail = "." + tail
+        else:
+            head, tail = rest, ""
+
+        # Protected n-grams that must keep their internal underscores
+        protected = {
+            # pairs
+            ("to", "q"),
+            ("to", "k"),
+            ("to", "v"),
+            ("to", "out"),
+            ("add", "q"),
+            ("add", "k"),
+            ("add", "v"),
+            ("txt", "mlp"),
+            ("img", "mlp"),
+            ("txt", "mod"),
+            ("img", "mod"),
+            # triplets
+            ("add", "q", "proj"),
+            ("add", "k", "proj"),
+            ("add", "v", "proj"),
+            ("to", "add", "out"),
+        }
+
+        prot_by_len = {}
+        for ng in protected:
+            prot_by_len.setdefault(len(ng), set()).add(ng)
+
+        parts = head.split("_")
+        merged = []
+        i = 0
+        lengths_desc = sorted(prot_by_len.keys(), reverse=True)
+
+        while i < len(parts):
+            matched = False
+            for L in lengths_desc:
+                if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
+                    merged.append("_".join(parts[i : i + L]))
+                    i += L
+                    matched = True
+                    break
+            if not matched:
+                merged.append(parts[i])
+                i += 1
+
+        head_converted = ".".join(merged)
+        converted_base = f"{prefix}.{head_converted}{tail}"
+        return converted_base + (("." + suffix) if suffix else "")
+
+    state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
+    converted_state_dict = {}
+    all_keys = list(state_dict.keys())
+    down_key = ".lora_down.weight"
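The heart of convert_key above is the greedy, longest-match-first merge over protected n-grams: underscores become dots, except inside names like to_q or add_q_proj that must keep their internal underscores. A reduced standalone sketch of that merge (not the commit's code; the protected set is trimmed to two entries for brevity):

# Longest n-grams are tried first, so "add_q_proj" wins over "add_q".
protected = {("to", "q"), ("add", "q", "proj")}  # reduced set for illustration

def merge(head: str) -> str:
    parts, out, i = head.split("_"), [], 0
    while i < len(parts):
        for L in (3, 2):  # longest match first
            if i + L <= len(parts) and tuple(parts[i : i + L]) in protected:
                out.append("_".join(parts[i : i + L]))
                i += L
                break
        else:
            out.append(parts[i])
            i += 1
    return ".".join(out)

print(merge("0_attn_to_q"))        # -> 0.attn.to_q
print(merge("5_attn_add_q_proj"))  # -> 5.attn.add_q_proj

Run against the full protected set, the same logic maps a flat checkpoint key such as transformer_blocks_0_attn_to_q.lora_down.weight to transformer_blocks.0.attn.to_q.lora_down.weight.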
@@ -6643,7 +6643,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
         state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
-        if has_alphas_in_sd:
+        has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+        if has_alphas_in_sd or has_lora_unet:
             state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
 
         out = (state_dict, metadata) if return_lora_metadata else state_dict
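The loader hunk above widens the trigger for conversion: previously only ".alpha" keys routed a state dict through the converter; now a "lora_unet_" prefix does as well. A toy illustration with fabricated stand-in state dicts (not real checkpoints):

sd_alpha = {"transformer_blocks_0_attn_to_q.alpha": 1.0}
sd_unet = {"lora_unet_transformer_blocks_0_attn_to_q.lora_down.weight": "tensor"}

for sd in (sd_alpha, sd_unet):
    has_alphas_in_sd = any(k.endswith(".alpha") for k in sd)
    has_lora_unet = any(k.startswith("lora_unet_") for k in sd)
    print(has_alphas_in_sd or has_lora_unet)  # True for both layouts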
@@ -24,9 +24,9 @@ except OptionalDependencyNotAvailable:
 else:
     _import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"]
     _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
+    _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
     _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
     _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
-    _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try: