mirror of https://github.com/huggingface/diffusers.git
resolve conflicts
@@ -16,7 +16,7 @@ import inspect
import os
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Literal, Optional, Union

import safetensors
import torch
@@ -25,7 +25,6 @@ from ..utils import (
    MIN_PEFT_VERSION,
    USE_PEFT_BACKEND,
    check_peft_version,
    convert_control_lora_state_dict_to_peft,
    convert_unet_state_dict_to_peft,
    delete_adapter_layers,
    get_adapter_name,
@@ -59,23 +58,11 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
}


def _maybe_adjust_config(config):
    """
    We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
    (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
    method removes the ambiguity by following what is described here:
    https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
    """
    # Track keys that have been explicitly removed to prevent re-adding them.
    deleted_keys = set()

def _maybe_raise_error_for_ambiguity(config):
    rank_pattern = config["rank_pattern"].copy()
    target_modules = config["target_modules"]
    original_r = config["r"]

    for key in list(rank_pattern.keys()):
        key_rank = rank_pattern[key]

        # try to detect ambiguity
        # `target_modules` can also be a str, in which case this loop would loop
        # over the chars of the str. The technically correct way to match LoRA keys
@@ -83,62 +70,12 @@ def _maybe_adjust_config(config):
        # But this cuts it for now.
        exact_matches = [mod for mod in target_modules if mod == key]
        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
        ambiguous_key = key

        if exact_matches and substring_matches:
            # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
            config["r"] = key_rank
            # remove the ambiguous key from `rank_pattern` and record it as deleted
            del config["rank_pattern"][key]
            deleted_keys.add(key)
            # For substring matches, add them with the original rank only if they haven't been assigned already
            for mod in substring_matches:
                if mod not in config["rank_pattern"] and mod not in deleted_keys:
                    config["rank_pattern"][mod] = original_r

            # Update the rest of the target modules with the original rank if not already set and not deleted
            for mod in target_modules:
                if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
                    config["rank_pattern"][mod] = original_r

    # Handle alphas to deal with cases like:
    # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
    has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
    if has_different_ranks:
        config["lora_alpha"] = config["r"]
        alpha_pattern = {}
        for module_name, rank in config["rank_pattern"].items():
            alpha_pattern[module_name] = rank
        config["alpha_pattern"] = alpha_pattern

    return config
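
As a concrete illustration of the ambiguity described in the docstring above, here is a hedged sketch with made-up module names and ranks (none of these values come from the diff); the "after" dict traces the adjustment logic shown in this hunk:

```py
# Before: "proj_out" is both an exact target module and a prefix of the longer names,
# and its rank (8) differs from the default `r` (4).
config = {
    "r": 4,
    "lora_alpha": 4,
    "target_modules": ["proj_out", "blocks.0.proj_out", "blocks.1.proj_out"],
    "rank_pattern": {"proj_out": 8},
}

# After the adjustment: the ambiguous key's rank becomes the default, the key itself is
# dropped from `rank_pattern`, the longer names keep the original rank, and the alphas
# are set equal to the ranks (see the alpha handling above).
adjusted = {
    "r": 8,
    "lora_alpha": 8,
    "target_modules": ["proj_out", "blocks.0.proj_out", "blocks.1.proj_out"],
    "rank_pattern": {"blocks.0.proj_out": 4, "blocks.1.proj_out": 4},
    "alpha_pattern": {"blocks.0.proj_out": 4, "blocks.1.proj_out": 4},
}
```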


def _maybe_adjust_config_for_control_lora(config):
    """
    Split the Control LoRA `target_modules` into proper LoRA targets and `modules_to_save`: entries that name full
    `weight` tensors (and their matching biases) become `modules_to_save`, while the remaining entries stay in
    `target_modules`.
    """

    target_modules_before = config["target_modules"]
    target_modules = []
    modules_to_save = []

    for module in target_modules_before:
        if module.endswith("weight"):
            base_name = ".".join(module.split(".")[:-1])
            modules_to_save.append(base_name)
        elif module.endswith("bias"):
            base_name = ".".join(module.split(".")[:-1])
            if ".".join([base_name, "weight"]) in target_modules_before:
                modules_to_save.append(base_name)
            else:
                target_modules.append(base_name)
        else:
            target_modules.append(module)

    config["target_modules"] = list(set(target_modules))
    config["modules_to_save"] = list(set(modules_to_save))

    return config
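
A rough illustration of the split performed above (the module names are made up, not taken from an actual Control LoRA checkpoint):

```py
config = {"target_modules": ["conv_in.weight", "conv_in.bias", "attn.to_q", "norm_out.bias"]}
config = _maybe_adjust_config_for_control_lora(config)

# config["modules_to_save"] -> {"conv_in"}                # full weight plus its matching bias
# config["target_modules"]  -> {"attn.to_q", "norm_out"}  # regular LoRA targets (order may vary)
```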
    if is_peft_version("<", "0.14.1"):
        raise ValueError(
            "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
        )


class PeftAdapterMixin:
@@ -156,6 +93,8 @@ class PeftAdapterMixin:
    """

    _hf_peft_config_loaded = False
    # kwargs for prepare_model_for_compiled_hotswap, if required
    _prepare_lora_hotswap_kwargs: Optional[dict] = None

    @classmethod
    # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
@@ -173,7 +112,9 @@ class PeftAdapterMixin:
        """
        return _func_optionally_disable_offloading(_pipeline=_pipeline)

    def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
    def load_lora_adapter(
        self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
    ):
        r"""
        Loads a LoRA adapter into the underlying model.

@@ -217,6 +158,29 @@ class PeftAdapterMixin:
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap (`bool`, *optional*):
                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
                in-place. This means that, instead of loading an additional adapter, this will take the existing
                adapter weights and replace them with the weights of the new adapter. This can be faster and more
                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
                torch.compile, loading the new adapter does not require recompilation of the model. When using
                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.

                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
                to call an additional method before loading the adapter:

                ```py
                pipeline = ...  # load diffusers pipeline
                max_rank = ...  # the highest rank among all LoRAs that you want to load
                # call *before* compiling and loading the LoRA adapter
                pipeline.enable_lora_hotswap(target_rank=max_rank)
                pipeline.load_lora_weights(file_name)
                # optionally compile the model now
                ```

                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                limitations to this technique, which are documented here:
                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
        """
        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
        from peft.tuners.tuners_utils import BaseTunerLayer
@@ -267,17 +231,15 @@ class PeftAdapterMixin:
            state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

        if len(state_dict) > 0:
            if adapter_name in getattr(self, "peft_config", {}):
            if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
                raise ValueError(
                    f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
                )

            # Control LoRA from SAI is different from BFL Control LoRA
            # https://huggingface.co/stabilityai/control-lora/
            is_control_lora = "lora_controlnet" in state_dict
            if is_control_lora:
                del state_dict["lora_controlnet"]
                state_dict = convert_control_lora_state_dict_to_peft(state_dict)
            elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
                raise ValueError(
                    f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
                    "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
                )

            # check with the first key whether the state dict is in PEFT format
            first_key = next(iter(state_dict.keys()))
@@ -289,18 +251,18 @@ class PeftAdapterMixin:
                # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
                # Bias layers in LoRA only have a single dimension
                if "lora_B" in key and val.ndim > 1:
                    # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
                    rank[key] = val.shape[1]
                    # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
                    # We may run into some ambiguous configuration values when a model has module
                    # names sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
                    # for example) and they have different LoRA ranks.
                    rank[f"^{key}"] = val.shape[1]

            if network_alphas is not None and len(network_alphas) >= 1:
                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
            # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
            if is_control_lora:
                lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs)
            _maybe_raise_error_for_ambiguity(lora_config_kwargs)

            if "use_dora" in lora_config_kwargs:
                if lora_config_kwargs["use_dora"]:
@@ -339,11 +301,71 @@ class PeftAdapterMixin:
                if is_peft_version(">=", "0.13.1"):
                    peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

            if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
                if is_peft_version(">", "0.14.0"):
                    from peft.utils.hotswap import (
                        check_hotswap_configs_compatible,
                        hotswap_adapter_from_state_dict,
                        prepare_model_for_compiled_hotswap,
                    )
                else:
                    msg = (
                        "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
                        "from source."
                    )
                    raise ImportError(msg)

            if hotswap:

                def map_state_dict_for_hotswap(sd):
                    # For hotswapping, we need the adapter name to be present in the state dict keys
                    new_sd = {}
                    for k, v in sd.items():
                        if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
                            k = k[: -len(".weight")] + f".{adapter_name}.weight"
                        elif k.endswith("lora_B.bias"):  # lora_bias=True option
                            k = k[: -len(".bias")] + f".{adapter_name}.bias"
                        new_sd[k] = v
                    return new_sd
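                # For illustration (hypothetical key names; e.g. adapter_name="default_0"):
                #   "blocks.0.attn.to_q.lora_A.weight" -> "blocks.0.attn.to_q.lora_A.default_0.weight"
                #   "blocks.0.attn.to_q.lora_B.bias"   -> "blocks.0.attn.to_q.lora_B.default_0.bias"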

            # To handle scenarios where we cannot successfully set the state dict. If it's unsuccessful,
            # we should also delete the `peft_config` associated with the `adapter_name`.
            try:
                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
                incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
                if hotswap:
                    state_dict = map_state_dict_for_hotswap(state_dict)
                    check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
                    try:
                        hotswap_adapter_from_state_dict(
                            model=self,
                            state_dict=state_dict,
                            adapter_name=adapter_name,
                            config=lora_config,
                        )
                    except Exception as e:
                        logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
                        raise
                    # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
                    # it to None
                    incompatible_keys = None
                else:
                    inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
                    incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

                if self._prepare_lora_hotswap_kwargs is not None:
                    # For hotswapping of compiled models or adapters with different ranks.
                    # If the user called `enable_lora_hotswap`, `prepare_model_for_compiled_hotswap` must be called:
                    # - after the first adapter was loaded
                    # - before the model is compiled and the 2nd adapter is being hotswapped in
                    # Therefore, it needs to be called here
                    prepare_model_for_compiled_hotswap(
                        self, config=lora_config, **self._prepare_lora_hotswap_kwargs
                    )
                    # We only want to call prepare_model_for_compiled_hotswap once
                    self._prepare_lora_hotswap_kwargs = None

                # Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
                if not self._hf_peft_config_loaded:
                    self._hf_peft_config_loaded = True
            except Exception as e:
                # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
                if hasattr(self, "peft_config"):
@@ -803,3 +825,36 @@ class PeftAdapterMixin:
            # Pop also the corresponding adapter from the config
            if hasattr(self, "peft_config"):
                self.peft_config.pop(adapter_name, None)

    def enable_lora_hotswap(
        self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
    ) -> None:
        """Enables hotswapping of LoRA adapters.

        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
        the loaded adapters differ.

        Args:
            target_rank (`int`, *optional*, defaults to `128`):
                The highest rank among all the adapters that will be loaded.

            check_compiled (`str`, *optional*, defaults to `"error"`):
                How to handle the case when the model is already compiled, which should generally be avoided. The
                options are:
                - "error" (default): raise an error
                - "warn": issue a warning
                - "ignore": do nothing
        """
        if getattr(self, "peft_config", {}):
            if check_compiled == "error":
                raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
            elif check_compiled == "warn":
                logger.warning(
                    "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
                )
            elif check_compiled != "ignore":
                raise ValueError(
                    f"check_compiled should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
                )

        self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
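
A possible end-to-end ordering implied by the docstrings above, shown as a hedged sketch: the model id, checkpoint file names, and adapter name are placeholders, the compiled module may be `transformer` or `unet` depending on the pipeline, and it is assumed that the pipeline-level `load_lora_weights` forwards `adapter_name` and `hotswap` down to `load_lora_adapter`.

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.float16).to("cuda")

# 1. Call before the first adapter is loaded (and before compiling);
#    use the highest rank among the LoRAs you plan to swap in.
pipe.enable_lora_hotswap(target_rank=64)

# 2. Load the first adapter under a name, then optionally compile.
pipe.load_lora_weights("first_lora.safetensors", adapter_name="default_0")
pipe.transformer = torch.compile(pipe.transformer)  # or pipe.unet, depending on the pipeline

# 3. Swap the second adapter into the same adapter slot; with hotswap=True this
#    should not trigger a recompilation.
pipe.load_lora_weights("second_lora.safetensors", adapter_name="default_0", hotswap=True)
```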

@@ -126,7 +126,7 @@ from .state_dict_utils import (
    convert_state_dict_to_kohya,
    convert_state_dict_to_peft,
    convert_unet_state_dict_to_peft,
    convert_control_lora_state_dict_to_peft,
    state_dict_all_zero,
)
from .typing_utils import _get_detailed_type, _is_valid_type