Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[LoRA] refactor lora loading at the model-level (#11719)
* factor out stuff from load_lora_adapter().
* simplifying text encoder lora loading.
* fix peft.py
* fix logging locations.
* formatting
* fix
* update
* update
* update
@@ -34,7 +34,6 @@ from ..utils import (
     delete_adapter_layers,
     deprecate,
     get_adapter_name,
-    get_peft_kwargs,
     is_accelerate_available,
     is_peft_available,
     is_peft_version,
@@ -46,14 +45,13 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import _create_lora_config
 from ..utils.state_dict_utils import _load_sft_state_dict_metadata


 if is_transformers_available():
     from transformers import PreTrainedModel

-    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
-
 if is_peft_available():
     from peft.tuners.tuners_utils import BaseTunerLayer

@@ -352,8 +350,6 @@ def _load_lora_into_text_encoder(
         )
         peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

-    from peft import LoraConfig
-
    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
    # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
    # their prefixes.
@@ -377,60 +373,25 @@ def _load_lora_into_text_encoder(
         # convert state dict
         state_dict = convert_state_dict_to_peft(state_dict)

-        for name, _ in text_encoder_attn_modules(text_encoder):
-            for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                rank_key = f"{name}.{module}.lora_B.weight"
-                if rank_key not in state_dict:
-                    continue
-                rank[rank_key] = state_dict[rank_key].shape[1]
-
-        for name, _ in text_encoder_mlp_modules(text_encoder):
-            for module in ("fc1", "fc2"):
-                rank_key = f"{name}.{module}.lora_B.weight"
-                if rank_key not in state_dict:
-                    continue
-                rank[rank_key] = state_dict[rank_key].shape[1]
+        for name, _ in text_encoder.named_modules():
+            if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
+                rank_key = f"{name}.lora_B.weight"
+                if rank_key in state_dict:
+                    rank[rank_key] = state_dict[rank_key].shape[1]

         if network_alphas is not None:
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

-        if metadata is not None:
-            lora_config_kwargs = metadata
-        else:
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
-
-        if "use_dora" in lora_config_kwargs:
-            if lora_config_kwargs["use_dora"]:
-                if is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-            else:
-                if is_peft_version("<", "0.9.0"):
-                    lora_config_kwargs.pop("use_dora")
-
-        if "lora_bias" in lora_config_kwargs:
-            if lora_config_kwargs["lora_bias"]:
-                if is_peft_version("<=", "0.13.2"):
-                    raise ValueError(
-                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                    )
-            else:
-                if is_peft_version("<=", "0.13.2"):
-                    lora_config_kwargs.pop("lora_bias")
-
-        try:
-            lora_config = LoraConfig(**lora_config_kwargs)
-        except TypeError as e:
-            raise TypeError("`LoraConfig` class could not be instantiated.") from e
+        # create `LoraConfig`
+        lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)

         # adapter_name
         if adapter_name is None:
             adapter_name = get_adapter_name(text_encoder)

         # <Unsafe code
         is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)

         # inject LoRA layers and load the state dict
         # in transformers we automatically check whether the adapter name is already in use or not
         text_encoder.load_adapter(
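
The replacement loop above collapses the PR's two per-module-type loops into a single pass over `text_encoder.named_modules()`; in both versions the adapter rank is read off the second dimension of the corresponding `lora_B` weight. A minimal sketch of that shape convention, using a hypothetical module name and toy tensor sizes (not taken from the PR):

    import torch

    # Hypothetical PEFT-format entry: lora_B has shape (out_features, rank), so rank = shape[1].
    state_dict = {
        "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(768, 4),
    }

    rank = {}
    for name in ("text_model.encoder.layers.0.self_attn.q_proj",):  # stand-in for text_encoder.named_modules()
        if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
            rank_key = f"{name}.lora_B.weight"
            if rank_key in state_dict:
                rank[rank_key] = state_dict[rank_key].shape[1]

    print(rank)  # {'text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight': 4}
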
@@ -442,7 +403,6 @@ def _load_lora_into_text_encoder(

         # scale LoRA layers with `lora_scale`
         scale_lora_layers(text_encoder, weight=lora_scale)

         text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

         # Offload back.
@@ -453,10 +413,11 @@ def _load_lora_into_text_encoder(
     # Unsafe code />

     if prefix is not None and not state_dict:
+        model_class_name = text_encoder.__class__.__name__
         logger.warning(
-            f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
+            f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
             "This is safe to ignore if LoRA state dict didn't originally have any "
-            f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
+            f"{model_class_name} related params. You can also try specifying `prefix=None` "
             "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
             "https://github.com/huggingface/diffusers/issues/new"
         )
@@ -29,13 +29,13 @@ from ..utils import (
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
-    get_peft_kwargs,
     is_peft_available,
     is_peft_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
 from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
 from .unet_loader_utils import _maybe_expand_lora_scales

@@ -64,26 +64,6 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
 }


-def _maybe_raise_error_for_ambiguity(config):
-    rank_pattern = config["rank_pattern"].copy()
-    target_modules = config["target_modules"]
-
-    for key in list(rank_pattern.keys()):
-        # try to detect ambiguity
-        # `target_modules` can also be a str, in which case this loop would loop
-        # over the chars of the str. The technically correct way to match LoRA keys
-        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
-        # But this cuts it for now.
-        exact_matches = [mod for mod in target_modules if mod == key]
-        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
-
-        if exact_matches and substring_matches:
-            if is_peft_version("<", "0.14.1"):
-                raise ValueError(
-                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
-                )
-
-
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
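
The `_maybe_raise_error_for_ambiguity` helper removed above (re-added in `peft_utils` further down as `_maybe_raise_error_for_ambiguous_keys`) guards against a `rank_pattern` key that matches one entry of `target_modules` exactly while also being a substring of another, which older PEFT versions can resolve to the wrong module. A self-contained illustration with made-up module names:

    # Toy config: "proj" is an exact match for one target module and a substring of another.
    config = {"rank_pattern": {"proj": 8}, "target_modules": ["proj", "proj_out"]}

    for key in list(config["rank_pattern"].keys()):
        exact_matches = [mod for mod in config["target_modules"] if mod == key]
        substring_matches = [mod for mod in config["target_modules"] if key in mod and mod != key]
        if exact_matches and substring_matches:
            # the loader raises here when peft < 0.14.1
            print(f"ambiguous rank_pattern key: {key!r}")
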
@@ -191,7 +171,7 @@ class PeftAdapterMixin:
                 LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
                 initialize `LoraConfig`.
         """
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+        from peft import inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

         cache_dir = kwargs.pop("cache_dir", None)
@@ -216,7 +196,6 @@ class PeftAdapterMixin:
         )

         user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
@@ -275,38 +254,8 @@ class PeftAdapterMixin:
                    k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                }

-            if metadata is not None:
-                lora_config_kwargs = metadata
-            else:
-                lora_config_kwargs = get_peft_kwargs(
-                    rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
-                )
-            _maybe_raise_error_for_ambiguity(lora_config_kwargs)
-
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-
-            if "lora_bias" in lora_config_kwargs:
-                if lora_config_kwargs["lora_bias"]:
-                    if is_peft_version("<=", "0.13.2"):
-                        raise ValueError(
-                            "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<=", "0.13.2"):
-                        lora_config_kwargs.pop("lora_bias")
-
-            try:
-                lora_config = LoraConfig(**lora_config_kwargs)
-            except TypeError as e:
-                raise TypeError("`LoraConfig` class could not be instantiated.") from e
+            # create LoraConfig
+            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)

             # adapter_name
             if adapter_name is None:
@@ -317,9 +266,8 @@ class PeftAdapterMixin:
             # Now we remove any existing hooks to `_pipeline`.

             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
+            # otherwise loading LoRA weights will lead to an error.
             is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

             peft_kwargs = {}
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
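
The DoRA and `lora_bias` version gates removed in this hunk are preserved inside `_create_lora_config` (see the `peft_utils` hunk below). A hedged sketch of the gate it still enforces, exercised through a toy metadata dict; this assumes a diffusers checkout containing this PR and an installed `peft`:

    from diffusers.utils.peft_utils import _create_lora_config  # private helper introduced by this PR

    metadata = {"r": 4, "lora_alpha": 4, "rank_pattern": {}, "target_modules": ["to_q"], "use_dora": True}
    try:
        config = _create_lora_config(None, None, metadata, None)
        print("DoRA-enabled LoraConfig created:", config.use_dora)
    except ValueError as err:
        # raised only on peft < 0.9.0, matching the block that used to live inline here
        print("old peft detected:", err)
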
@@ -403,30 +351,7 @@ class PeftAdapterMixin:
                logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
                raise

-            warn_msg = ""
-            if incompatible_keys is not None:
-                # Check only for unexpected keys.
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
-                    if lora_unexpected_keys:
-                        warn_msg = (
-                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
-                            f" {', '.join(lora_unexpected_keys)}. "
-                        )
-
-                # Filter missing keys specific to the current adapter.
-                missing_keys = getattr(incompatible_keys, "missing_keys", None)
-                if missing_keys:
-                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
-                    if lora_missing_keys:
-                        warn_msg += (
-                            f"Loading adapter weights from state_dict led to missing keys in the model:"
-                            f" {', '.join(lora_missing_keys)}."
-                        )
-
-            if warn_msg:
-                logger.warning(warn_msg)
+            _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)

             # Offload back.
             if is_model_cpu_offload:
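
With the inline warning block gone, the call site simply forwards PEFT's report of unhandled keys to the new helper. A rough sketch of that hand-off; the namedtuple below only mimics the object `set_peft_model_state_dict` returns and is an assumption for illustration:

    from collections import namedtuple

    # Assumed stand-in for peft's incompatible-keys result (exposes .missing_keys / .unexpected_keys).
    IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])

    incompatible_keys = IncompatibleKeys(
        missing_keys=["transformer.foo.lora_A.default_0.weight"],  # hypothetical key
        unexpected_keys=[],
    )

    # The model-level loader now just delegates:
    # _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name="default_0")
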
@@ -436,10 +361,11 @@ class PeftAdapterMixin:
            # Unsafe code />

        if prefix is not None and not state_dict:
+            model_class_name = self.__class__.__name__
            logger.warning(
-                f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
+                f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
                "This is safe to ignore if LoRA state dict didn't originally have any "
-                f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
+                f"{model_class_name} related params. You can also try specifying `prefix=None` "
                "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
                "https://github.com/huggingface/diffusers/issues/new"
            )
@@ -21,9 +21,12 @@ from typing import Optional

 from packaging import version

-from .import_utils import is_peft_available, is_torch_available
+from . import logging
+from .import_utils import is_peft_available, is_peft_version, is_torch_available
+
+
+logger = logging.get_logger(__name__)

 if is_torch_available():
     import torch

@@ -288,3 +291,83 @@ def check_peft_version(min_version: str) -> None:
            f"The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )
+
+
+def _create_lora_config(
+    state_dict,
+    network_alphas,
+    metadata,
+    rank_pattern_dict,
+    is_unet: bool = True,
+):
+    from peft import LoraConfig
+
+    if metadata is not None:
+        lora_config_kwargs = metadata
+    else:
+        lora_config_kwargs = get_peft_kwargs(
+            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+        )
+
+    _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
+
+    # Version checks for DoRA and lora_bias
+    if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]:
+        if is_peft_version("<", "0.9.0"):
+            raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")
+
+    if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]:
+        if is_peft_version("<=", "0.13.2"):
+            raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.")
+
+    try:
+        return LoraConfig(**lora_config_kwargs)
+    except TypeError as e:
+        raise TypeError("`LoraConfig` class could not be instantiated.") from e
+
+
+def _maybe_raise_error_for_ambiguous_keys(config):
+    rank_pattern = config["rank_pattern"].copy()
+    target_modules = config["target_modules"]
+
+    for key in list(rank_pattern.keys()):
+        # try to detect ambiguity
+        # `target_modules` can also be a str, in which case this loop would loop
+        # over the chars of the str. The technically correct way to match LoRA keys
+        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
+        # But this cuts it for now.
+        exact_matches = [mod for mod in target_modules if mod == key]
+        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
+
+        if exact_matches and substring_matches:
+            if is_peft_version("<", "0.14.1"):
+                raise ValueError(
+                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
+                )
+
+
+def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
+    warn_msg = ""
+    if incompatible_keys is not None:
+        # Check only for unexpected keys.
+        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+        if unexpected_keys:
+            lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+            if lora_unexpected_keys:
+                warn_msg = (
+                    f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                    f" {', '.join(lora_unexpected_keys)}. "
+                )
+
+        # Filter missing keys specific to the current adapter.
+        missing_keys = getattr(incompatible_keys, "missing_keys", None)
+        if missing_keys:
+            lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+            if lora_missing_keys:
+                warn_msg += (
+                    f"Loading adapter weights from state_dict led to missing keys in the model:"
+                    f" {', '.join(lora_missing_keys)}."
+                )
+
+    if warn_msg:
+        logger.warning(warn_msg)
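
A rough usage sketch of the new helper, mirroring the two call sites above. When adapter metadata is available it is passed straight through to `LoraConfig`, so a small metadata dict (toy values, `rank_pattern` left empty) is enough to exercise the function; with metadata set, the `rank_pattern_dict` argument is never consulted. This assumes a diffusers checkout containing this PR and an installed `peft`:

    from diffusers.utils.peft_utils import _create_lora_config

    metadata = {
        "r": 4,
        "lora_alpha": 4,
        "rank_pattern": {},
        "target_modules": ["to_q", "to_k", "to_v"],
    }
    lora_config = _create_lora_config(
        state_dict=None, network_alphas=None, metadata=metadata, rank_pattern_dict=None
    )
    print(lora_config.r, sorted(lora_config.target_modules))  # 4 ['to_k', 'to_q', 'to_v']
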
@@ -1794,7 +1794,7 @@ class PeftLoraLoaderMixinTests:
        missing_key = [k for k in state_dict if "lora_A" in k][0]
        del state_dict[missing_key]

-        logger = logging.get_logger("diffusers.loaders.peft")
+        logger = logging.get_logger("diffusers.utils.peft_utils")
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(state_dict)
@@ -1829,7 +1829,7 @@ class PeftLoraLoaderMixinTests:
        unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
        state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)

-        logger = logging.get_logger("diffusers.loaders.peft")
+        logger = logging.get_logger("diffusers.utils.peft_utils")
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(state_dict)
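
The two test tweaks above only change which logger is captured, since the missing/unexpected-key warnings are now emitted from `diffusers.utils.peft_utils`. A condensed, self-contained sketch of that capture pattern; the warning call here is a stand-in for what `pipe.load_lora_weights(...)` triggers in the real test:

    from diffusers.utils import logging
    from diffusers.utils.testing_utils import CaptureLogger

    logger = logging.get_logger("diffusers.utils.peft_utils")
    logger.setLevel(30)  # logging.WARNING

    with CaptureLogger(logger) as cap_logger:
        logger.warning("Loading adapter weights from state_dict led to missing keys in the model: ...")

    assert "missing keys" in cap_logger.out
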
@@ -2006,9 +2006,6 @@ class PeftLoraLoaderMixinTests:

        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
-        logger.setLevel(logging.INFO)
-
        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

        denoiser_lora_config.lora_bias = False