From 6a1ff82d0830484202cf3166229add29e4626609 Mon Sep 17 00:00:00 2001
From: lavinal712
Date: Wed, 9 Apr 2025 07:41:14 +0000
Subject: [PATCH] resolve conflicts

---
 src/diffusers/loaders/peft.py   | 227 ++++++++++++++++++++------
 src/diffusers/utils/__init__.py |   2 +-
 2 files changed, 142 insertions(+), 87 deletions(-)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 3aa06ffb8e..9165c46f3c 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -16,7 +16,7 @@ import inspect
 import os
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Literal, Optional, Union
 
 import safetensors
 import torch
@@ -25,7 +25,6 @@ from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
     check_peft_version,
-    convert_control_lora_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
@@ -59,23 +58,11 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
 }
 
 
-def _maybe_adjust_config(config):
-    """
-    We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
-    (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
-    method removes the ambiguity by following what is described here:
-    https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
-    """
-    # Track keys that have been explicitly removed to prevent re-adding them.
-    deleted_keys = set()
-
+def _maybe_raise_error_for_ambiguity(config):
     rank_pattern = config["rank_pattern"].copy()
     target_modules = config["target_modules"]
-    original_r = config["r"]
 
     for key in list(rank_pattern.keys()):
-        key_rank = rank_pattern[key]
-
         # try to detect ambiguity
         # `target_modules` can also be a str, in which case this loop would loop
         # over the chars of the str. The technically correct way to match LoRA keys
@@ -83,62 +70,12 @@
         # But this cuts it for now.
exact_matches = [mod for mod in target_modules if mod == key] substring_matches = [mod for mod in target_modules if key in mod and mod != key] - ambiguous_key = key if exact_matches and substring_matches: - # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example) - config["r"] = key_rank - # remove the ambiguous key from `rank_pattern` and record it as deleted - del config["rank_pattern"][key] - deleted_keys.add(key) - # For substring matches, add them with the original rank only if they haven't been assigned already - for mod in substring_matches: - if mod not in config["rank_pattern"] and mod not in deleted_keys: - config["rank_pattern"][mod] = original_r - - # Update the rest of the target modules with the original rank if not already set and not deleted - for mod in target_modules: - if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys: - config["rank_pattern"][mod] = original_r - - # Handle alphas to deal with cases like: - # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777 - has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"] - if has_different_ranks: - config["lora_alpha"] = config["r"] - alpha_pattern = {} - for module_name, rank in config["rank_pattern"].items(): - alpha_pattern[module_name] = rank - config["alpha_pattern"] = alpha_pattern - - return config - - -def _maybe_adjust_config_for_control_lora(config): - """ - """ - - target_modules_before = config["target_modules"] - target_modules = [] - modules_to_save = [] - - for module in target_modules_before: - if module.endswith("weight"): - base_name = ".".join(module.split(".")[:-1]) - modules_to_save.append(base_name) - elif module.endswith("bias"): - base_name = ".".join(module.split(".")[:-1]) - if ".".join([base_name, "weight"]) in target_modules_before: - modules_to_save.append(base_name) - else: - target_modules.append(base_name) - else: - target_modules.append(module) - - config["target_modules"] = list(set(target_modules)) - config["modules_to_save"] = list(set(modules_to_save)) - - return config + if is_peft_version("<", "0.14.1"): + raise ValueError( + "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`." + ) class PeftAdapterMixin: @@ -156,6 +93,8 @@ class PeftAdapterMixin: """ _hf_peft_config_loaded = False + # kwargs for prepare_model_for_compiled_hotswap, if required + _prepare_lora_hotswap_kwargs: Optional[dict] = None @classmethod # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading @@ -173,7 +112,9 @@ class PeftAdapterMixin: """ return _func_optionally_disable_offloading(_pipeline=_pipeline) - def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs): + def load_lora_adapter( + self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs + ): r""" Loads a LoRA adapter into the underlying model. @@ -217,6 +158,29 @@ class PeftAdapterMixin: low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. 
This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer @@ -267,17 +231,15 @@ class PeftAdapterMixin: state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} if len(state_dict) > 0: - if adapter_name in getattr(self, "peft_config", {}): + if adapter_name in getattr(self, "peft_config", {}) and not hotswap: raise ValueError( f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." ) - - # Control LoRA from SAI is different from BFL Control LoRA - # https://huggingface.co/stabilityai/control-lora/ - is_control_lora = "lora_controlnet" in state_dict - if is_control_lora: - del state_dict["lora_controlnet"] - state_dict = convert_control_lora_state_dict_to_peft(state_dict) + elif adapter_name not in getattr(self, "peft_config", {}) and hotswap: + raise ValueError( + f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. " + "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." + ) # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) @@ -289,18 +251,18 @@ class PeftAdapterMixin: # Cannot figure out rank from lora layers that don't have atleast 2 dimensions. # Bias layers in LoRA only have a single dimension if "lora_B" in key and val.ndim > 1: - # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged. - rank[key] = val.shape[1] + # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol. + # We may run into some ambiguous configuration values when a model has module + # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`, + # for example) and they have different LoRA ranks. + rank[f"^{key}"] = val.shape[1] if network_alphas is not None and len(network_alphas) >= 1: alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")] network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) - # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged. 
-            lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
-            if is_control_lora:
-                lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs)
+            _maybe_raise_error_for_ambiguity(lora_config_kwargs)
 
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
@@ -339,11 +301,71 @@ class PeftAdapterMixin:
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
 
+            if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
+                if is_peft_version(">", "0.14.0"):
+                    from peft.utils.hotswap import (
+                        check_hotswap_configs_compatible,
+                        hotswap_adapter_from_state_dict,
+                        prepare_model_for_compiled_hotswap,
+                    )
+                else:
+                    msg = (
+                        "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
+                        "from source."
+                    )
+                    raise ImportError(msg)
+
+            if hotswap:
+
+                def map_state_dict_for_hotswap(sd):
+                    # For hotswapping, we need the adapter name to be present in the state dict keys
+                    new_sd = {}
+                    for k, v in sd.items():
+                        if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
+                            k = k[: -len(".weight")] + f".{adapter_name}.weight"
+                        elif k.endswith("lora_B.bias"):  # lora_bias=True option
+                            k = k[: -len(".bias")] + f".{adapter_name}.bias"
+                        new_sd[k] = v
+                    return new_sd
+
             # To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
             # we should also delete the `peft_config` associated to the `adapter_name`.
             try:
-                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
-                incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+                if hotswap:
+                    state_dict = map_state_dict_for_hotswap(state_dict)
+                    check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
+                    try:
+                        hotswap_adapter_from_state_dict(
+                            model=self,
+                            state_dict=state_dict,
+                            adapter_name=adapter_name,
+                            config=lora_config,
+                        )
+                    except Exception as e:
+                        logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
+                        raise
+                    # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
+                    # it to None
+                    incompatible_keys = None
+                else:
+                    inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                    incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+
+                if self._prepare_lora_hotswap_kwargs is not None:
+                    # For hotswapping of compiled models or adapters with different ranks.
+                    # If the user called enable_lora_hotswap, we need to ensure it is called:
+                    # - after the first adapter was loaded
+                    # - before the model is compiled and the 2nd adapter is being hotswapped in
+                    # Therefore, it needs to be called here
+                    prepare_model_for_compiled_hotswap(
+                        self, config=lora_config, **self._prepare_lora_hotswap_kwargs
+                    )
+                    # We only want to call prepare_model_for_compiled_hotswap once
+                    self._prepare_lora_hotswap_kwargs = None
+
+                # Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
+                if not self._hf_peft_config_loaded:
+                    self._hf_peft_config_loaded = True
             except Exception as e:
                 # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"): @@ -803,3 +825,36 @@ class PeftAdapterMixin: # Pop also the corresponding adapter from the config if hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) + + def enable_lora_hotswap( + self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error" + ) -> None: + """Enables the possibility to hotswap LoRA adapters. + + Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of + the loaded adapters differ. + + Args: + target_rank (`int`, *optional*, defaults to `128`): + The highest rank among all the adapters that will be loaded. + + check_compiled (`str`, *optional*, defaults to `"error"`): + How to handle the case when the model is already compiled, which should generally be avoided. The + options are: + - "error" (default): raise an error + - "warn": issue a warning + - "ignore": do nothing + """ + if getattr(self, "peft_config", {}): + if check_compiled == "error": + raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.") + elif check_compiled == "warn": + logger.warning( + "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." + ) + elif check_compiled != "ignore": + raise ValueError( + f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead." + ) + + self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled} diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 2c17b7ca75..438faa23e5 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -126,7 +126,7 @@ from .state_dict_utils import ( convert_state_dict_to_kohya, convert_state_dict_to_peft, convert_unet_state_dict_to_peft, - convert_control_lora_state_dict_to_peft, + state_dict_all_zero, ) from .typing_utils import _get_detailed_type, _is_valid_type