mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] support more ComfyUI LoRAs for Flux 🚨 (#10985)

* support more ComfyUI LoRAs.

* fix

* fixes

* revert changes in LoRA base.

* no position_embedding

* 🚨 introduce a breaking change to let peft handle module ambiguity

* styling

* remove position embeddings.

* improvements.

* style

* make info instead of NotImplementedError

* Update src/diffusers/loaders/peft.py

Co-authored-by: hlky <hlky@hlky.ac>

* add example.

* robust checks

* updates

---------

Co-authored-by: hlky <hlky@hlky.ac>
Sayak Paul
2025-04-09 09:17:05 +05:30
committed by GitHub
parent f685981ed0
commit 6bfacf0418
4 changed files with 244 additions and 55 deletions
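
For context, a minimal usage sketch of what this change targets: loading a ComfyUI-style Flux LoRA through the regular diffusers API. The LoRA repository id and filename below are placeholders, not taken from this commit.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
# Hypothetical ComfyUI-style LoRA; the kohya/ComfyUI key layout is converted on load.
pipe.load_lora_weights("some-user/comfyui-flux-lora", weight_name="comfyui_flux_lora.safetensors")
image = pipe("a tiny robot watering a bonsai tree", num_inference_steps=28, guidance_scale=3.5).images[0]
image.save("out.png")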

src/diffusers/loaders/lora_conversion_utils.py (View File)

@@ -13,15 +13,22 @@
# limitations under the License.
import re
from typing import List
import torch
from ..utils import is_peft_version, logging
from ..utils import is_peft_version, logging, state_dict_all_zero
logger = logging.get_logger(__name__)
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
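# Quick illustration, not part of the diff: `swap_scale_shift` reorders a weight
# stored as [shift; scale] along dim 0 into [scale; shift], e.g.
#   w = torch.arange(6.0).reshape(6, 1)   # rows 0-2 play "shift", rows 3-5 play "scale"
#   swap_scale_shift(w).squeeze()         # tensor([3., 4., 5., 0., 1., 2.])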
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
# 1. get all state_dict_keys
all_keys = list(state_dict.keys())
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
return diffusers_name
@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# All credits go to `kohya-ss`.
# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
def _convert_kohya_flux_lora_to_diffusers(state_dict):
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
@@ -341,7 +348,8 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
# scale weight by alpha and dim
rank = down_weight.shape[0]
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
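# Illustration, not part of the diff: any factorization with
# scale_down * scale_up == scale works, because
#   (lora_up * scale_up) @ (lora_down * scale_down) == scale * (lora_up @ lora_down).
# E.g. scale = 4 -> scale_down = 2 and scale_up = 2, as the comment above says;
# the exact split chosen by the upstream code may differ.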
@@ -362,7 +370,10 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
sd_lora_rank = down_weight.shape[0]
# scale weight by alpha and dim
alpha = sds_sd.pop(sds_key + ".alpha")
default_alpha = torch.tensor(
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
scale = alpha / sd_lora_rank
# calculate scale_down and scale_up
@@ -516,10 +527,103 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
f"transformer.single_transformer_blocks.{i}.norm.linear",
)
# TODO: alphas.
def assign_remaining_weights(assignments, source):
for lora_key in ["lora_A", "lora_B"]:
orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
for target_fmt, source_fmt, transform in assignments:
target_key = target_fmt.format(lora_key=lora_key)
source_key = source_fmt.format(orig_lora_key=orig_lora_key)
value = source.pop(source_key)
if transform:
value = transform(value)
ait_sd[target_key] = value
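# Worked example, not part of the diff, of one assignment from the "guidance_in"
# mapping below, with lora_key == "lora_A" (so orig_lora_key == "lora_down"):
#   target_key = "time_text_embed.guidance_embedder.linear_1.lora_A.weight"
#   source_key = "lora_unet_guidance_in_in_layer.lora_down.weight"
# The ComfyUI/kohya tensor is popped from `sds_sd`, optionally transformed, and
# stored under the diffusers-style name in `ait_sd`.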
if any("guidance_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
"lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
"lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)
if any("img_in" in k for k in sds_sd):
assign_remaining_weights(
[
("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
],
sds_sd,
)
if any("txt_in" in k for k in sds_sd):
assign_remaining_weights(
[
("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
],
sds_sd,
)
if any("time_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
"lora_unet_time_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
"lora_unet_time_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)
if any("vector_in" in k for k in sds_sd):
assign_remaining_weights(
[
(
"time_text_embed.text_embedder.linear_1.{lora_key}.weight",
"lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
None,
),
(
"time_text_embed.text_embedder.linear_2.{lora_key}.weight",
"lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
None,
),
],
sds_sd,
)
if any("final_layer" in k for k in sds_sd):
# Note the scale/shift swap (via `swap_scale_shift`) when processing "final_layer".
assign_remaining_weights(
[
(
"norm_out.linear.{lora_key}.weight",
"lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
swap_scale_shift,
),
("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
],
sds_sd,
)
remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
if not all(k.startswith("lora_te") for k in remaining_keys):
if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
@@ -680,10 +784,98 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
if has_peft_state_dict:
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
return state_dict
# Another nonstandard layout: a "mixture" state dict whose keys start with "lora_transformer_".
has_mixture = any(
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
)
# ComfyUI.
if not has_mixture:
state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
has_position_embedding = any("position_embedding" in k for k in state_dict)
if has_position_embedding:
zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
if zero_status_pe:
logger.info(
"The `position_embedding` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"The state_dict has position_embedding LoRA params and we currently do not support them. "
"Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
if has_t5xxl:
zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
if zero_status_t5:
logger.info(
"The `t5xxl` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
"Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
if has_diffb:
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
if zero_status_diff_b:
logger.info(
"The `diff_b` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"`diff_b` keys found in the state dict which are currently unsupported. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
if has_norm_diff:
zero_status_diff = state_dict_all_zero(state_dict, ".diff")
if zero_status_diff:
logger.info(
"The `diff` LoRA params are all zeros which make them ineffective. "
"So, we will purge them out of the curret state dict to make loading possible."
)
else:
logger.info(
"Normalization diff keys found in the state dict which are currently unsupported. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
limit_substrings = ["lora_down", "lora_up"]
if any("alpha" in k for k in state_dict):
limit_substrings.append("alpha")
state_dict = {
_custom_replace(k, limit_substrings): v
for k, v in state_dict.items()
if k.startswith(("lora_unet_", "lora_te_"))
}
if any("text_projection" in k for k in state_dict):
logger.info(
"`text_projection` keys found in the `state_dict` which are unexpected. "
"So, we will filter out those keys. Open an issue if this is a problem - "
"https://github.com/huggingface/diffusers/issues/new."
)
state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
if has_mixture:
return _convert_mixture_state_dict_to_diffusers(state_dict)
@@ -798,6 +990,26 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
return new_state_dict
def _custom_replace(key: str, substrings: List[str]) -> str:
# Replaces the "."s with "_"s up to (the first occurrence of) one of the `substrings`.
# Example:
# lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
match = re.search(pattern, key)
if match:
start_sub = match.start()
if start_sub > 0 and key[start_sub - 1] == ".":
boundary = start_sub - 1
else:
boundary = start_sub
left = key[:boundary].replace(".", "_")
right = key[boundary:]
return left + right
else:
return key.replace(".", "_")
def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
converted_state_dict = {}
original_state_dict_keys = list(original_state_dict.keys())
@@ -806,11 +1018,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
inner_dim = 3072
mlp_ratio = 4.0
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in
converted_state_dict[

src/diffusers/loaders/peft.py (View File)

@@ -58,23 +58,11 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
}
def _maybe_adjust_config(config):
"""
We may run into some ambiguous configuration values when a model has module names that share a common prefix
(`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) but have different LoRA ranks. This
method removes the ambiguity by following what is described here:
https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
"""
# Track keys that have been explicitly removed to prevent re-adding them.
deleted_keys = set()
def _maybe_raise_error_for_ambiguity(config):
rank_pattern = config["rank_pattern"].copy()
target_modules = config["target_modules"]
original_r = config["r"]
for key in list(rank_pattern.keys()):
key_rank = rank_pattern[key]
# try to detect ambiguity
# `target_modules` can also be a str, in which case this loop would loop
# over the chars of the str. The technically correct way to match LoRA keys
@@ -82,35 +70,12 @@ def _maybe_adjust_config(config):
# But this cuts it for now.
exact_matches = [mod for mod in target_modules if mod == key]
substring_matches = [mod for mod in target_modules if key in mod and mod != key]
ambiguous_key = key
if exact_matches and substring_matches:
# if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
config["r"] = key_rank
# remove the ambiguous key from `rank_pattern` and record it as deleted
del config["rank_pattern"][key]
deleted_keys.add(key)
# For substring matches, add them with the original rank only if they haven't been assigned already
for mod in substring_matches:
if mod not in config["rank_pattern"] and mod not in deleted_keys:
config["rank_pattern"][mod] = original_r
# Update the rest of the target modules with the original rank if not already set and not deleted
for mod in target_modules:
if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
config["rank_pattern"][mod] = original_r
# Handle alphas to deal with cases like:
# https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
if has_different_ranks:
config["lora_alpha"] = config["r"]
alpha_pattern = {}
for module_name, rank in config["rank_pattern"].items():
alpha_pattern[module_name] = rank
config["alpha_pattern"] = alpha_pattern
return config
if is_peft_version("<", "0.14.1"):
raise ValueError(
"There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
)
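# Illustration, not part of the diff, of the ambiguity being guarded against.
# Given a hypothetical config such as
#   {"r": 16,
#    "target_modules": ["proj_out", "blocks.transformer.proj_out"],
#    "rank_pattern": {"proj_out": 8}}
# the key "proj_out" is an exact match for one target module and a substring of
# another, so it is unclear which module the rank-8 override applies to; older
# peft versions cannot resolve this, hence the error above for peft < 0.14.1.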
class PeftAdapterMixin:
@@ -286,16 +251,18 @@ class PeftAdapterMixin:
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
rank[key] = val.shape[1]
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
# We may run into some ambiguous configuration values when a model has module
# names that share a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
# for example) but have different LoRA ranks.
rank[f"^{key}"] = val.shape[1]
if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:

src/diffusers/utils/__init__.py (View File)

@@ -126,6 +126,7 @@ from .state_dict_utils import (
convert_state_dict_to_kohya,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
state_dict_all_zero,
)
from .typing_utils import _get_detailed_type, _is_valid_type

src/diffusers/utils/state_dict_utils.py (View File)

@@ -17,9 +17,14 @@ State dict utilities: utility methods for converting state dicts easily
import enum
from .import_utils import is_torch_available
from .logging import get_logger
if is_torch_available():
import torch
logger = get_logger(__name__)
@@ -333,3 +338,12 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
return kohya_ss_state_dict
def state_dict_all_zero(state_dict, filter_str=None):
if filter_str is not None:
if isinstance(filter_str, str):
filter_str = [filter_str]
state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
return all(torch.all(param == 0).item() for param in state_dict.values())
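# Usage illustration, not part of the diff:
#   sd = {
#       "te.position_embedding.lora_A.weight": torch.zeros(4, 4),
#       "unet.to_q.lora_A.weight": torch.ones(4, 4),
#   }
#   state_dict_all_zero(sd, "position_embedding")  # True: only matching params are checked
#   state_dict_all_zero(sd)                        # False: the unfiltered dict has nonzero params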