diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 20fcb61f3b..5ec16ff299 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -13,15 +13,22 @@ # limitations under the License. import re +from typing import List import torch -from ..utils import is_peft_version, logging +from ..utils import is_peft_version, logging, state_dict_all_zero logger = logging.get_logger(__name__) +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5): # 1. get all state_dict_keys all_keys = list(state_dict.keys()) @@ -313,6 +320,7 @@ def _convert_text_encoder_lora_key(key, lora_name): # Be aware that this is the new diffusers convention and the rest of the code might # not utilize it yet. diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + return diffusers_name @@ -331,8 +339,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): # The utilities under `_convert_kohya_flux_lora_to_diffusers()` -# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py -# All credits go to `kohya-ss`. +# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py def _convert_kohya_flux_lora_to_diffusers(state_dict): def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key): if sds_key + ".lora_down.weight" not in sds_sd: @@ -341,7 +348,8 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict): # scale weight by alpha and dim rank = down_weight.shape[0] - alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar + default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False) + alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() # alpha is scalar scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here # calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2 @@ -362,7 +370,10 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict): sd_lora_rank = down_weight.shape[0] # scale weight by alpha and dim - alpha = sds_sd.pop(sds_key + ".alpha") + default_alpha = torch.tensor( + sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False + ) + alpha = sds_sd.pop(sds_key + ".alpha", default_alpha) scale = alpha / sd_lora_rank # calculate scale_down and scale_up @@ -516,10 +527,103 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict): f"transformer.single_transformer_blocks.{i}.norm.linear", ) + # TODO: alphas. 
+ def assign_remaining_weights(assignments, source): + for lora_key in ["lora_A", "lora_B"]: + orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up" + for target_fmt, source_fmt, transform in assignments: + target_key = target_fmt.format(lora_key=lora_key) + source_key = source_fmt.format(orig_lora_key=orig_lora_key) + value = source.pop(source_key) + if transform: + value = transform(value) + ait_sd[target_key] = value + + if any("guidance_in" in k for k in sds_sd): + assign_remaining_weights( + [ + ( + "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight", + "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight", + None, + ), + ( + "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight", + "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight", + None, + ), + ], + sds_sd, + ) + + if any("img_in" in k for k in sds_sd): + assign_remaining_weights( + [ + ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None), + ], + sds_sd, + ) + + if any("txt_in" in k for k in sds_sd): + assign_remaining_weights( + [ + ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None), + ], + sds_sd, + ) + + if any("time_in" in k for k in sds_sd): + assign_remaining_weights( + [ + ( + "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight", + "lora_unet_time_in_in_layer.{orig_lora_key}.weight", + None, + ), + ( + "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight", + "lora_unet_time_in_out_layer.{orig_lora_key}.weight", + None, + ), + ], + sds_sd, + ) + + if any("vector_in" in k for k in sds_sd): + assign_remaining_weights( + [ + ( + "time_text_embed.text_embedder.linear_1.{lora_key}.weight", + "lora_unet_vector_in_in_layer.{orig_lora_key}.weight", + None, + ), + ( + "time_text_embed.text_embedder.linear_2.{lora_key}.weight", + "lora_unet_vector_in_out_layer.{orig_lora_key}.weight", + None, + ), + ], + sds_sd, + ) + + if any("final_layer" in k for k in sds_sd): + # Notice the swap in processing for "final_layer". + assign_remaining_weights( + [ + ( + "norm_out.linear.{lora_key}.weight", + "lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight", + swap_scale_shift, + ), + ("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None), + ], + sds_sd, + ) + remaining_keys = list(sds_sd.keys()) te_state_dict = {} if remaining_keys: - if not all(k.startswith("lora_te") for k in remaining_keys): + if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys): raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}") for key in remaining_keys: if not key.endswith("lora_down.weight"): @@ -680,10 +784,98 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict): if has_peft_state_dict: state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")} return state_dict + # Another weird one. has_mixture = any( k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict ) + + # ComfyUI. 
+    if not has_mixture:
+        state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
+        state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
+
+        has_position_embedding = any("position_embedding" in k for k in state_dict)
+        if has_position_embedding:
+            zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
+            if zero_status_pe:
+                logger.info(
+                    "The `position_embedding` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+
+            else:
+                logger.info(
+                    "The state_dict has `position_embedding` LoRA params and we currently do not support them. "
+                    "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
+
+        has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
+        if has_t5xxl:
+            zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
+            if zero_status_t5:
+                logger.info(
+                    "The `t5xxl` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out. "
+                    "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
+
+        has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
+        if has_diffb:
+            zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
+            if zero_status_diff_b:
+                logger.info(
+                    "The `diff_b` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "`diff_b` keys found in the state dict, which are currently unsupported. "
+                    "So, we will filter out those keys. Open an issue if this is a problem - "
+                    "https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
+
+        has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
+        if has_norm_diff:
+            zero_status_diff = state_dict_all_zero(state_dict, ".diff")
+            if zero_status_diff:
+                logger.info(
+                    "The `diff` LoRA params are all zeros, which makes them ineffective. "
+                    "So, we will purge them out of the current state dict to make loading possible."
+                )
+            else:
+                logger.info(
+                    "Normalization diff keys found in the state dict, which are currently unsupported. "
+                    "So, we will filter out those keys. Open an issue if this is a problem - "
+                    "https://github.com/huggingface/diffusers/issues/new."
+                )
+            state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
+
+        limit_substrings = ["lora_down", "lora_up"]
+        if any("alpha" in k for k in state_dict):
+            limit_substrings.append("alpha")
+
+        state_dict = {
+            _custom_replace(k, limit_substrings): v
+            for k, v in state_dict.items()
+            if k.startswith(("lora_unet_", "lora_te_"))
+        }
+
+        if any("text_projection" in k for k in state_dict):
+            logger.info(
+                "`text_projection` keys found in the `state_dict`, which are unexpected. "
+                "So, we will filter out those keys. Open an issue if this is a problem - "
+                "https://github.com/huggingface/diffusers/issues/new."
+            )
+            state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
+
     if has_mixture:
         return _convert_mixture_state_dict_to_diffusers(state_dict)
@@ -798,6 +990,26 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
     return new_state_dict


+def _custom_replace(key: str, substrings: List[str]) -> str:
+    # Replaces the "."s with "_"s up to the `substrings`.
+    # Example:
+    # lora_unet.foo.bar.lora_down.weight -> lora_unet_foo_bar.lora_down.weight
+    pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
+
+    match = re.search(pattern, key)
+    if match:
+        start_sub = match.start()
+        if start_sub > 0 and key[start_sub - 1] == ".":
+            boundary = start_sub - 1
+        else:
+            boundary = start_sub
+        left = key[:boundary].replace(".", "_")
+        right = key[boundary:]
+        return left + right
+    else:
+        return key.replace(".", "_")
+
+
 def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     converted_state_dict = {}
     original_state_dict_keys = list(original_state_dict.keys())
@@ -806,11 +1018,6 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
     inner_dim = 3072
     mlp_ratio = 4.0

-    def swap_scale_shift(weight):
-        shift, scale = weight.chunk(2, dim=0)
-        new_weight = torch.cat([scale, shift], dim=0)
-        return new_weight
-
     for lora_key in ["lora_A", "lora_B"]:
         ## time_text_embed.timestep_embedder <- time_in
         converted_state_dict[
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 1809a5d56c..9165c46f3c 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -58,23 +58,11 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
 }


-def _maybe_adjust_config(config):
-    """
-    We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
-    (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
-    method removes the ambiguity by following what is described here:
-    https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
-    """
-    # Track keys that have been explicitly removed to prevent re-adding them.
-    deleted_keys = set()
-
+def _maybe_raise_error_for_ambiguity(config):
     rank_pattern = config["rank_pattern"].copy()
     target_modules = config["target_modules"]
-    original_r = config["r"]

     for key in list(rank_pattern.keys()):
-        key_rank = rank_pattern[key]
-
         # try to detect ambiguity
         # `target_modules` can also be a str, in which case this loop would loop
         # over the chars of the str. The technically correct way to match LoRA keys
@@ -82,35 +70,12 @@
         # But this cuts it for now.
exact_matches = [mod for mod in target_modules if mod == key] substring_matches = [mod for mod in target_modules if key in mod and mod != key] - ambiguous_key = key if exact_matches and substring_matches: - # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example) - config["r"] = key_rank - # remove the ambiguous key from `rank_pattern` and record it as deleted - del config["rank_pattern"][key] - deleted_keys.add(key) - # For substring matches, add them with the original rank only if they haven't been assigned already - for mod in substring_matches: - if mod not in config["rank_pattern"] and mod not in deleted_keys: - config["rank_pattern"][mod] = original_r - - # Update the rest of the target modules with the original rank if not already set and not deleted - for mod in target_modules: - if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys: - config["rank_pattern"][mod] = original_r - - # Handle alphas to deal with cases like: - # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777 - has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"] - if has_different_ranks: - config["lora_alpha"] = config["r"] - alpha_pattern = {} - for module_name, rank in config["rank_pattern"].items(): - alpha_pattern[module_name] = rank - config["alpha_pattern"] = alpha_pattern - - return config + if is_peft_version("<", "0.14.1"): + raise ValueError( + "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`." + ) class PeftAdapterMixin: @@ -286,16 +251,18 @@ class PeftAdapterMixin: # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. # Bias layers in LoRA only have a single dimension if "lora_B" in key and val.ndim > 1: - # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged. - rank[key] = val.shape[1] + # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. + # We may run into some ambiguous configuration values when a model has module + # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, + # for example) and they have different LoRA ranks. + rank[f"^{key}"] = val.shape[1] if network_alphas is not None and len(network_alphas) >= 1: alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) - # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged. 
- lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs) + _maybe_raise_error_for_ambiguity(lora_config_kwargs) if "use_dora" in lora_config_kwargs: if lora_config_kwargs["use_dora"]: diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 50a4707727..438faa23e5 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -126,6 +126,7 @@ from .state_dict_utils import ( convert_state_dict_to_kohya, convert_state_dict_to_peft, convert_unet_state_dict_to_peft, + state_dict_all_zero, ) from .typing_utils import _get_detailed_type, _is_valid_type diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 62b114ba67..f23fddd286 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -17,9 +17,14 @@ State dict utilities: utility methods for converting state dicts easily import enum +from .import_utils import is_torch_available from .logging import get_logger +if is_torch_available(): + import torch + + logger = get_logger(__name__) @@ -333,3 +338,12 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs): kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight)) return kohya_ss_state_dict + + +def state_dict_all_zero(state_dict, filter_str=None): + if filter_str is not None: + if isinstance(filter_str, str): + filter_str = [filter_str] + state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)} + + return all(torch.all(param == 0).item() for param in state_dict.values())
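Below is a brief usage sketch for the two new helpers introduced by this patch. It is not part of the diff itself: the imports assume the patched diffusers is installed, and the sample keys and tensors are invented purely for illustration.

# Illustrative sketch only; not part of the patch. Assumes the patched diffusers is installed.
# The sample keys and tensors below are invented for demonstration purposes.
import torch

from diffusers.loaders.lora_conversion_utils import _custom_replace
from diffusers.utils import state_dict_all_zero

# `_custom_replace` underscores the "."s only up to the first limiting substring,
# so the LoRA-specific suffix keeps its dotted form.
key = "lora_unet.foo.bar.lora_down.weight"
assert _custom_replace(key, ["lora_down", "lora_up"]) == "lora_unet_foo_bar.lora_down.weight"

# `state_dict_all_zero` reports whether every tensor (optionally filtered by substring)
# is exactly zero; the converter uses it to decide when unsupported params can be purged.
sd = {
    "te.position_embedding.diff": torch.zeros(4),
    "unet.to_q.lora_down.weight": torch.ones(2, 2),
}
assert state_dict_all_zero(sd, "position_embedding")  # only the zeroed entry is inspected
assert not state_dict_all_zero(sd)  # the all-ones tensor fails the check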