
resolve conflicts

This commit is contained in:
lavinal712
2025-04-09 07:41:14 +00:00
parent ce2b34bba7
commit 6a1ff82d08
2 changed files with 142 additions and 87 deletions

View File

@@ -16,7 +16,7 @@ import inspect
import os
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
@@ -25,7 +25,6 @@ from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
convert_control_lora_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
@@ -59,23 +58,11 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
}
def _maybe_adjust_config(config):
"""
We may run into ambiguous configuration values when one module name is a substring of another
(`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and the two have different LoRA
ranks. This method removes the ambiguity by following what is described here:
https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
"""
# Track keys that have been explicitly removed to prevent re-adding them.
deleted_keys = set()
def _maybe_raise_error_for_ambiguity(config):
rank_pattern = config["rank_pattern"].copy()
target_modules = config["target_modules"]
original_r = config["r"]
for key in list(rank_pattern.keys()):
key_rank = rank_pattern[key]
# try to detect ambiguity
# `target_modules` can also be a str, in which case this loop would iterate over
# the characters of the str. The technically correct way to match LoRA keys
@@ -83,62 +70,12 @@ def _maybe_adjust_config(config):
# But this cuts it for now.
exact_matches = [mod for mod in target_modules if mod == key]
substring_matches = [mod for mod in target_modules if key in mod and mod != key]
ambiguous_key = key
if exact_matches and substring_matches:
# if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
config["r"] = key_rank
# remove the ambiguous key from `rank_pattern` and record it as deleted
del config["rank_pattern"][key]
deleted_keys.add(key)
# For substring matches, add them with the original rank only if they haven't been assigned already
for mod in substring_matches:
if mod not in config["rank_pattern"] and mod not in deleted_keys:
config["rank_pattern"][mod] = original_r
# Update the rest of the target modules with the original rank if not already set and not deleted
for mod in target_modules:
if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
config["rank_pattern"][mod] = original_r
# Handle alphas to deal with cases like:
# https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
if has_different_ranks:
config["lora_alpha"] = config["r"]
alpha_pattern = {}
for module_name, rank in config["rank_pattern"].items():
alpha_pattern[module_name] = rank
config["alpha_pattern"] = alpha_pattern
return config
def _maybe_adjust_config_for_control_lora(config):
"""
"""
target_modules_before = config["target_modules"]
target_modules = []
modules_to_save = []
for module in target_modules_before:
if module.endswith("weight"):
base_name = ".".join(module.split(".")[:-1])
modules_to_save.append(base_name)
elif module.endswith("bias"):
base_name = ".".join(module.split(".")[:-1])
if ".".join([base_name, "weight"]) in target_modules_before:
modules_to_save.append(base_name)
else:
target_modules.append(base_name)
else:
target_modules.append(module)
config["target_modules"] = list(set(target_modules))
config["modules_to_save"] = list(set(modules_to_save))
return config
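To illustrate the split performed by `_maybe_adjust_config_for_control_lora`, here is a minimal standalone sketch (the module names are made up and not taken from a real Control LoRA checkpoint):
```py
# Minimal re-statement of the split above, applied to a toy target_modules list.
before = [
    "conv_in.weight",  # full weight tensor -> parent module goes to modules_to_save
    "conv_in.bias",    # bias whose weight is also listed -> parent module goes to modules_to_save
    "norm_out.bias",   # bias without a matching weight -> parent module stays a LoRA target
    "to_q",            # plain module name -> stays a LoRA target
]

target_modules, modules_to_save = [], []
for module in before:
    base = ".".join(module.split(".")[:-1])
    if module.endswith("weight"):
        modules_to_save.append(base)
    elif module.endswith("bias"):
        (modules_to_save if f"{base}.weight" in before else target_modules).append(base)
    else:
        target_modules.append(module)

print(sorted(set(target_modules)))   # ['norm_out', 'to_q']
print(sorted(set(modules_to_save)))  # ['conv_in']
```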
if is_peft_version("<", "0.14.1"):
raise ValueError(
"There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
)
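For context on the ambiguity being checked here, the sketch below (illustrative only; the matching helper is not PEFT's actual implementation) shows why a bare key such as `proj_out` is ambiguous when another module name contains it, and why anchoring the key with `^` resolves this:
```py
import re

modules = ["proj_out", "blocks.transformer.proj_out"]

# A suffix-style lookup: an unanchored key matches any module name that ends with it.
def suffix_matches(key, module_name):
    return re.search(rf"(^|\.){re.escape(key)}$", module_name) is not None

# The bare key "proj_out" matches *both* modules, so a single rank entry for it is ambiguous.
print([m for m in modules if suffix_matches("proj_out", m)])
# -> ['proj_out', 'blocks.transformer.proj_out']

# A "^"-anchored key, interpreted as a regex from the start of the module name, matches only
# the top-level module (see https://github.com/huggingface/peft/pull/2419).
print([m for m in modules if re.match(r"^proj_out$", m)])
# -> ['proj_out']
```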
class PeftAdapterMixin:
@@ -156,6 +93,8 @@ class PeftAdapterMixin:
"""
_hf_peft_config_loaded = False
# kwargs for prepare_model_for_compiled_hotswap, if required
_prepare_lora_hotswap_kwargs: Optional[dict] = None
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
@@ -173,7 +112,9 @@ class PeftAdapterMixin:
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
def load_lora_adapter(
self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
):
r"""
Loads a LoRA adapter into the underlying model.
@@ -217,6 +158,29 @@ class PeftAdapterMixin:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -267,17 +231,15 @@ class PeftAdapterMixin:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}):
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
raise ValueError(
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
)
# Control LoRA from SAI is different from BFL Control LoRA
# https://huggingface.co/stabilityai/control-lora/
is_control_lora = "lora_controlnet" in state_dict
if is_control_lora:
del state_dict["lora_controlnet"]
state_dict = convert_control_lora_state_dict_to_peft(state_dict)
elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
raise ValueError(
f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
)
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
@@ -289,18 +251,18 @@ class PeftAdapterMixin:
# Cannot figure out rank from LoRA layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
rank[key] = val.shape[1]
# Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
# We may run into ambiguous configuration values when one module name is a substring
# of another (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example)
# and the two have different LoRA ranks.
rank[f"^{key}"] = val.shape[1]
if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
if is_control_lora:
lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
@@ -339,11 +301,71 @@ class PeftAdapterMixin:
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
if is_peft_version(">", "0.14.0"):
from peft.utils.hotswap import (
check_hotswap_configs_compatible,
hotswap_adapter_from_state_dict,
prepare_model_for_compiled_hotswap,
)
else:
msg = (
"Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
"from source."
)
raise ImportError(msg)
if hotswap:
def map_state_dict_for_hotswap(sd):
# For hotswapping, we need the adapter name to be present in the state dict keys
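# e.g. "to_q.lora_A.weight" becomes "to_q.lora_A.<adapter_name>.weight"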
new_sd = {}
for k, v in sd.items():
if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
k = k[: -len(".weight")] + f".{adapter_name}.weight"
elif k.endswith("lora_B.bias"): # lora_bias=True option
k = k[: -len(".bias")] + f".{adapter_name}.bias"
new_sd[k] = v
return new_sd
# To handle scenarios where we cannot successfully set the state dict. If it's unsuccessful,
# we should also delete the `peft_config` associated with the `adapter_name`.
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if hotswap:
state_dict = map_state_dict_for_hotswap(state_dict)
check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
try:
hotswap_adapter_from_state_dict(
model=self,
state_dict=state_dict,
adapter_name=adapter_name,
config=lora_config,
)
except Exception as e:
logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
raise
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set
# it to None
incompatible_keys = None
else:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if self._prepare_lora_hotswap_kwargs is not None:
# For hotswapping of compiled models or adapters with different ranks.
# If the user called enable_lora_hotswap, we need to ensure it is called:
# - after the first adapter was loaded
# - before the model is compiled and the 2nd adapter is being hotswapped in
# Therefore, it needs to be called here
prepare_model_for_compiled_hotswap(
self, config=lora_config, **self._prepare_lora_hotswap_kwargs
)
# We only want to call prepare_model_for_compiled_hotswap once
self._prepare_lora_hotswap_kwargs = None
# Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
@@ -803,3 +825,36 @@ class PeftAdapterMixin:
# Pop also the corresponding adapter from the config
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
def enable_lora_hotswap(
self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
) -> None:
"""Enables the possibility to hotswap LoRA adapters.
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
the loaded adapters differ.
Args:
target_rank (`int`, *optional*, defaults to `128`):
The highest rank among all the adapters that will be loaded.
check_compiled (`str`, *optional*, defaults to `"error"`):
How to handle the case when the model is already compiled, which should generally be avoided. The
options are:
- "error" (default): raise an error
- "warn": issue a warning
- "ignore": do nothing
"""
if getattr(self, "peft_config", {}):
if check_compiled == "error":
raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
elif check_compiled == "warn":
logger.warning(
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
elif check_compiled != "ignore":
raise ValueError(
f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
)
self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
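Taken together, the intended call order for hotswapping looks roughly like the sketch below. This is a hedged example based on the docstrings above: the checkpoint and LoRA identifiers are placeholders, the denoiser attribute name depends on the pipeline, and `load_lora_weights` is assumed to forward `hotswap`/`adapter_name` to `load_lora_adapter`.
```py
import torch
from diffusers import DiffusionPipeline

# Placeholder identifiers; substitute a real checkpoint and two LoRAs of possibly different ranks.
pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.float16).to("cuda")

# Must be called before the first adapter is loaded if the model will be compiled
# or if the adapters have different ranks; target_rank is the highest rank among them.
pipe.enable_lora_hotswap(target_rank=64)

pipe.load_lora_weights("some/lora-1", adapter_name="default_0")
pipe.unet = torch.compile(pipe.unet)  # optional; the denoiser attribute depends on the pipeline

# Replace the weights of "default_0" in place; with hotswapping, this should not trigger recompilation.
pipe.load_lora_weights("some/lora-2", adapter_name="default_0", hotswap=True)
```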

View File

@@ -126,7 +126,7 @@ from .state_dict_utils import (
convert_state_dict_to_kohya,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
convert_control_lora_state_dict_to_peft,
state_dict_all_zero,
)
from .typing_utils import _get_detailed_type, _is_valid_type