Mirror of https://github.com/huggingface/diffusers.git
[LoRA] support more ComfyUI LoRAs for Flux 🚨 (#10985)

* support more ComfyUI loras
* fix
* fixes
* revert changes in LoRA base
* no position_embedding
* 🚨 introduce a breaking change to let peft handle module ambiguity
* styling
* remove position embeddings
* improvements
* style
* log info instead of raising NotImplementedError
* Update src/diffusers/loaders/peft.py

Co-authored-by: hlky <hlky@hlky.ac>

* add example
* robust checks
* updates

---------

Co-authored-by: hlky <hlky@hlky.ac>
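Example usage (not part of the commit): a minimal sketch of loading a ComfyUI-style Flux LoRA with diffusers after this change. The LoRA repo id and weight filename below are illustrative placeholders, not taken from this PR.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
# Hypothetical ComfyUI-trained LoRA checkpoint; substitute a real repo id / filename.
pipe.load_lora_weights("your-namespace/your-comfyui-flux-lora", weight_name="lora.safetensors")
image = pipe("a puppy wearing a tiny raincoat", num_inference_steps=28).images[0]
image.save("out.png")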
@@ -13,15 +13,22 @@
# limitations under the License.

import re
from typing import List

import torch

from ..utils import is_peft_version, logging
from ..utils import is_peft_version, logging, state_dict_all_zero


logger = logging.get_logger(__name__)


def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight


def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
    # 1. get all state_dict_keys
    all_keys = list(state_dict.keys())
@@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
    # Be aware that this is the new diffusers convention and the rest of the code might
    # not utilize it yet.
    diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")

    return diffusers_name
@@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):


# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# All credits go to `kohya-ss`.
# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
def _convert_kohya_flux_lora_to_diffusers(state_dict):
    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
        if sds_key + ".lora_down.weight" not in sds_sd:
@@ -341,7 +348,8 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):

        # scale weight by alpha and dim
        rank = down_weight.shape[0]
        alpha = sds_sd.pop(sds_key + ".alpha").item()  # alpha is scalar
        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()  # alpha is scalar
        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here

        # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
@@ -362,7 +370,10 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
        sd_lora_rank = down_weight.shape[0]

        # scale weight by alpha and dim
        alpha = sds_sd.pop(sds_key + ".alpha")
        default_alpha = torch.tensor(
            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
        )
        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
        scale = alpha / sd_lora_rank

        # calculate scale_down and scale_up
@@ -516,10 +527,103 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
                f"transformer.single_transformer_blocks.{i}.norm.linear",
            )

        # TODO: alphas.
        def assign_remaining_weights(assignments, source):
            for lora_key in ["lora_A", "lora_B"]:
                orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
                for target_fmt, source_fmt, transform in assignments:
                    target_key = target_fmt.format(lora_key=lora_key)
                    source_key = source_fmt.format(orig_lora_key=orig_lora_key)
                    value = source.pop(source_key)
                    if transform:
                        value = transform(value)
                    ait_sd[target_key] = value

        if any("guidance_in" in k for k in sds_sd):
            assign_remaining_weights(
                [
                    (
                        "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
                        "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
                        None,
                    ),
                    (
                        "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
                        "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
                        None,
                    ),
                ],
                sds_sd,
            )

        if any("img_in" in k for k in sds_sd):
            assign_remaining_weights(
                [
                    ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
                ],
                sds_sd,
            )

        if any("txt_in" in k for k in sds_sd):
            assign_remaining_weights(
                [
                    ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
                ],
                sds_sd,
            )

        if any("time_in" in k for k in sds_sd):
            assign_remaining_weights(
                [
                    (
                        "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
                        "lora_unet_time_in_in_layer.{orig_lora_key}.weight",
                        None,
                    ),
                    (
                        "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
                        "lora_unet_time_in_out_layer.{orig_lora_key}.weight",
                        None,
                    ),
                ],
                sds_sd,
            )

        if any("vector_in" in k for k in sds_sd):
            assign_remaining_weights(
                [
                    (
                        "time_text_embed.text_embedder.linear_1.{lora_key}.weight",
                        "lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
                        None,
                    ),
                    (
                        "time_text_embed.text_embedder.linear_2.{lora_key}.weight",
                        "lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
                        None,
                    ),
                ],
                sds_sd,
            )

        if any("final_layer" in k for k in sds_sd):
            # Notice the swap in processing for "final_layer".
            assign_remaining_weights(
                [
                    (
                        "norm_out.linear.{lora_key}.weight",
                        "lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
                        swap_scale_shift,
                    ),
                    ("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
                ],
                sds_sd,
            )

        remaining_keys = list(sds_sd.keys())
        te_state_dict = {}
        if remaining_keys:
            if not all(k.startswith("lora_te") for k in remaining_keys):
            if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
                raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
            for key in remaining_keys:
                if not key.endswith("lora_down.weight"):
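For reference, a small illustration (not part of the diff) of how one assignment tuple above expands into concrete state-dict keys for both LoRA matrices:

# Sketch only: expansion of one (target_fmt, source_fmt, transform) tuple handled by assign_remaining_weights.
template_target = "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
template_source = "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight"
for lora_key, orig_lora_key in [("lora_A", "lora_down"), ("lora_B", "lora_up")]:
    print(template_target.format(lora_key=lora_key), "<-", template_source.format(orig_lora_key=orig_lora_key))
# time_text_embed.guidance_embedder.linear_1.lora_A.weight <- lora_unet_guidance_in_in_layer.lora_down.weight
# time_text_embed.guidance_embedder.linear_1.lora_B.weight <- lora_unet_guidance_in_in_layer.lora_up.weight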
@@ -680,10 +784,98 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
    if has_peft_state_dict:
        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
        return state_dict

    # Another weird one.
    has_mixture = any(
        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
    )

    # ComfyUI.
    if not has_mixture:
        state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
        state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}

        has_position_embedding = any("position_embedding" in k for k in state_dict)
        if has_position_embedding:
            zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
            if zero_status_pe:
                logger.info(
                    "The `position_embedding` LoRA params are all zeros, which makes them ineffective. "
                    "So, we will purge them out of the current state dict to make loading possible."
                )
            else:
                logger.info(
                    "The state_dict has position_embedding LoRA params and we currently do not support them. "
                    "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
                )
            state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}

        has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
        if has_t5xxl:
            zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
            if zero_status_t5:
                logger.info(
                    "The `t5xxl` LoRA params are all zeros, which makes them ineffective. "
                    "So, we will purge them out of the current state dict to make loading possible."
                )
            else:
                logger.info(
                    "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out. "
                    "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
                )
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}

        has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
        if has_diffb:
            zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
            if zero_status_diff_b:
                logger.info(
                    "The `diff_b` LoRA params are all zeros, which makes them ineffective. "
                    "So, we will purge them out of the current state dict to make loading possible."
                )
            else:
                logger.info(
                    "`diff_b` keys found in the state dict which are currently unsupported. "
                    "So, we will filter out those keys. Open an issue if this is a problem - "
                    "https://github.com/huggingface/diffusers/issues/new."
                )
            state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}

        has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
        if has_norm_diff:
            zero_status_diff = state_dict_all_zero(state_dict, ".diff")
            if zero_status_diff:
                logger.info(
                    "The `diff` LoRA params are all zeros, which makes them ineffective. "
                    "So, we will purge them out of the current state dict to make loading possible."
                )
            else:
                logger.info(
                    "Normalization diff keys found in the state dict which are currently unsupported. "
                    "So, we will filter out those keys. Open an issue if this is a problem - "
                    "https://github.com/huggingface/diffusers/issues/new."
                )
            state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}

        limit_substrings = ["lora_down", "lora_up"]
        if any("alpha" in k for k in state_dict):
            limit_substrings.append("alpha")

        state_dict = {
            _custom_replace(k, limit_substrings): v
            for k, v in state_dict.items()
            if k.startswith(("lora_unet_", "lora_te_"))
        }

        if any("text_projection" in k for k in state_dict):
            logger.info(
                "`text_projection` keys found in the `state_dict` which are unexpected. "
                "So, we will filter out those keys. Open an issue if this is a problem - "
                "https://github.com/huggingface/diffusers/issues/new."
            )
            state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}

    if has_mixture:
        return _convert_mixture_state_dict_to_diffusers(state_dict)
@@ -798,6 +990,26 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
    return new_state_dict


def _custom_replace(key: str, substrings: List[str]) -> str:
    # Replaces the "."s with "_"s up to the first occurrence of any of the `substrings`.
    # Example:
    # lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"

    match = re.search(pattern, key)
    if match:
        start_sub = match.start()
        if start_sub > 0 and key[start_sub - 1] == ".":
            boundary = start_sub - 1
        else:
            boundary = start_sub
        left = key[:boundary].replace(".", "_")
        right = key[boundary:]
        return left + right
    else:
        return key.replace(".", "_")


def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    converted_state_dict = {}
    original_state_dict_keys = list(original_state_dict.keys())
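For reference, a small sketch (not part of the diff, assuming `_custom_replace` is in scope) of what the helper does to a ComfyUI-style key; the key below is a made-up example in the style of the comment above:

# Sketch only: dots are converted to underscores up to the first of the limit substrings.
key = "lora_unet.single_blocks.0.linear1.lora_down.weight"  # hypothetical key
print(_custom_replace(key, ["lora_down", "lora_up"]))
# -> lora_unet_single_blocks_0_linear1.lora_down.weight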
@@ -806,11 +1018,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    inner_dim = 3072
    mlp_ratio = 4.0

    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

    for lora_key in ["lora_A", "lora_B"]:
        ## time_text_embed.timestep_embedder <- time_in
        converted_state_dict[
@@ -58,23 +58,11 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
}


def _maybe_adjust_config(config):
    """
    We may run into some ambiguous configuration values when a model has module names sharing a common prefix
    (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
    method removes the ambiguity by following what is described here:
    https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
    """
    # Track keys that have been explicitly removed to prevent re-adding them.
    deleted_keys = set()

def _maybe_raise_error_for_ambiguity(config):
    rank_pattern = config["rank_pattern"].copy()
    target_modules = config["target_modules"]
    original_r = config["r"]

    for key in list(rank_pattern.keys()):
        key_rank = rank_pattern[key]

        # try to detect ambiguity
        # `target_modules` can also be a str, in which case this loop would loop
        # over the chars of the str. The technically correct way to match LoRA keys
@@ -82,35 +70,12 @@ def _maybe_adjust_config(config):
        # But this cuts it for now.
        exact_matches = [mod for mod in target_modules if mod == key]
        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
        ambiguous_key = key

        if exact_matches and substring_matches:
            # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
            config["r"] = key_rank
            # remove the ambiguous key from `rank_pattern` and record it as deleted
            del config["rank_pattern"][key]
            deleted_keys.add(key)
            # For substring matches, add them with the original rank only if they haven't been assigned already
            for mod in substring_matches:
                if mod not in config["rank_pattern"] and mod not in deleted_keys:
                    config["rank_pattern"][mod] = original_r

        # Update the rest of the target modules with the original rank if not already set and not deleted
        for mod in target_modules:
            if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
                config["rank_pattern"][mod] = original_r

    # Handle alphas to deal with cases like:
    # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
    has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
    if has_different_ranks:
        config["lora_alpha"] = config["r"]
        alpha_pattern = {}
        for module_name, rank in config["rank_pattern"].items():
            alpha_pattern[module_name] = rank
        config["alpha_pattern"] = alpha_pattern

    return config
            if is_peft_version("<", "0.14.1"):
                raise ValueError(
                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
                )


class PeftAdapterMixin:
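To make the ambiguity described above concrete, here is a small sketch (not part of the diff) of a config where a bare key matches one target module exactly and is a substring of another, which is the case the new check rejects on older peft versions:

# Sketch only: mirrors the exact-match / substring-match detection shown above.
config = {
    "r": 4,
    "target_modules": ["proj_out", "blocks.transformer.proj_out"],
    "rank_pattern": {"proj_out": 8},  # ambiguous: which proj_out does rank 8 refer to?
}
key = "proj_out"
exact_matches = [mod for mod in config["target_modules"] if mod == key]
substring_matches = [mod for mod in config["target_modules"] if key in mod and mod != key]
print(bool(exact_matches and substring_matches))  # True -> ambiguous; peft < 0.14.1 cannot disambiguate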
@@ -286,16 +251,18 @@ class PeftAdapterMixin:
                # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
                # Bias layers in LoRA only have a single dimension
                if "lora_B" in key and val.ndim > 1:
                    # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
                    rank[key] = val.shape[1]
                    # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
                    # We may run into some ambiguous configuration values when a model has module
                    # names sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
                    # for example) and they have different LoRA ranks.
                    rank[f"^{key}"] = val.shape[1]

            if network_alphas is not None and len(network_alphas) >= 1:
                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
            # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
            _maybe_raise_error_for_ambiguity(lora_config_kwargs)

            if "use_dora" in lora_config_kwargs:
                if lora_config_kwargs["use_dora"]:
@@ -126,6 +126,7 @@ from .state_dict_utils import (
    convert_state_dict_to_kohya,
    convert_state_dict_to_peft,
    convert_unet_state_dict_to_peft,
    state_dict_all_zero,
)
from .typing_utils import _get_detailed_type, _is_valid_type
@@ -17,9 +17,14 @@ State dict utilities: utility methods for converting state dicts easily

import enum

from .import_utils import is_torch_available
from .logging import get_logger


if is_torch_available():
    import torch


logger = get_logger(__name__)
@@ -333,3 +338,12 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
        kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))

    return kohya_ss_state_dict


def state_dict_all_zero(state_dict, filter_str=None):
    if filter_str is not None:
        if isinstance(filter_str, str):
            filter_str = [filter_str]
        state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}

    return all(torch.all(param == 0).item() for param in state_dict.values())
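For reference, a quick sketch (not part of the diff) of how the new state_dict_all_zero helper behaves; it is exported from diffusers.utils per the import hunk above:

# Sketch only: returns True when every (optionally filtered) tensor is all zeros.
import torch
from diffusers.utils import state_dict_all_zero

sd = {
    "position_embedding.lora_down.weight": torch.zeros(4, 8),
    "linear1.lora_down.weight": torch.ones(4, 8),
}
print(state_dict_all_zero(sd, "position_embedding"))  # True: the filtered params are all zeros
print(state_dict_all_zero(sd))  # False: the linear1 params are non-zero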