Mirror of https://github.com/huggingface/diffusers.git
change peft.py

Drop the Control LoRA handling from peft.py (the convert_control_lora_state_dict_to_peft import, the _maybe_adjust_config_for_control_lora helper, and the lora_controlnet conversion branch), add the CogView4, HiDream, and HunyuanVideo Framepack entries to _SET_ADAPTER_SCALE_FN_MAPPING, strip prefixes with str.removeprefix() instead of slicing/replace, and fix the "unsucessful"/"prevous" typos.
@@ -25,7 +25,6 @@ from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
     check_peft_version,
-    convert_control_lora_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
@@ -53,9 +52,12 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
     "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
     "LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
     "SanaTransformer2DModel": lambda model_cls, weights: weights,
     "AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
     "Lumina2Transformer2DModel": lambda model_cls, weights: weights,
     "WanTransformer3DModel": lambda model_cls, weights: weights,
+    "CogView4Transformer2DModel": lambda model_cls, weights: weights,
+    "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
+    "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
 }
 
 
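The mapping above keys a per-class weight-normalization callable by model class name; each entry added here is the identity function, so those transformers take adapter scales as given. A minimal dispatch sketch (the call-site name set_adapter_scale and the stand-in class are illustrative, not the library's actual code):

_SET_ADAPTER_SCALE_FN_MAPPING = {
    "SanaTransformer2DModel": lambda model_cls, weights: weights,  # identity mapping
}

class SanaTransformer2DModel:  # stand-in class for the demo
    pass

def set_adapter_scale(model, weights):
    # Look up the per-class hook by class name and let it normalize the weights.
    scale_fn = _SET_ADAPTER_SCALE_FN_MAPPING[type(model).__name__]
    return scale_fn(type(model), weights)

assert set_adapter_scale(SanaTransformer2DModel(), 0.75) == 0.75

Keying on __name__ keeps the table flat and lets a new model class opt in with a single line.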
@@ -79,33 +81,6 @@ def _maybe_raise_error_for_ambiguity(config):
             )
 
 
-def _maybe_adjust_config_for_control_lora(config):
-    """
-    """
-
-    target_modules_before = config["target_modules"]
-    target_modules = []
-    modules_to_save = []
-
-    for module in target_modules_before:
-        if module.endswith("weight"):
-            base_name = ".".join(module.split(".")[:-1])
-            modules_to_save.append(base_name)
-        elif module.endswith("bias"):
-            base_name = ".".join(module.split(".")[:-1])
-            if ".".join([base_name, "weight"]) in target_modules_before:
-                modules_to_save.append(base_name)
-            else:
-                target_modules.append(base_name)
-        else:
-            target_modules.append(module)
-
-    config["target_modules"] = list(set(target_modules))
-    config["modules_to_save"] = list(set(modules_to_save))
-
-    return config
-
-
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
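For reviewers tracing the removal: the deleted helper rewrites a PEFT config so that entries ending in "weight" (and "bias" entries whose matching "weight" is also present) move from target_modules, which get LoRA layers, to modules_to_save, which are stored as full tensors. A standalone copy with a made-up target_modules list shows the split:

def _maybe_adjust_config_for_control_lora(config):
    target_modules_before = config["target_modules"]
    target_modules = []
    modules_to_save = []

    for module in target_modules_before:
        if module.endswith("weight"):
            # fully trained weight: keep the whole module instead of adapting it
            base_name = ".".join(module.split(".")[:-1])
            modules_to_save.append(base_name)
        elif module.endswith("bias"):
            base_name = ".".join(module.split(".")[:-1])
            if ".".join([base_name, "weight"]) in target_modules_before:
                modules_to_save.append(base_name)
            else:
                target_modules.append(base_name)
        else:
            target_modules.append(module)

    config["target_modules"] = list(set(target_modules))
    config["modules_to_save"] = list(set(modules_to_save))
    return config

config = {"target_modules": ["to_q", "norm.weight", "norm.bias", "proj.bias"]}
print(_maybe_adjust_config_for_control_lora(config))
# e.g. {'target_modules': ['proj', 'to_q'], 'modules_to_save': ['norm']} (set order varies)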
@@ -256,7 +231,7 @@ class PeftAdapterMixin:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
         if prefix is not None:
-            state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+            state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
 
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
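This replacement is behavior-preserving: for keys that pass the startswith filter, slicing off len(prefix) + 1 characters and str.removeprefix (Python 3.9+) give identical results, but the latter states the intent directly. A quick check with toy keys:

prefix = "transformer"
state_dict = {
    "transformer.blocks.0.lora_A.weight": 1,
    "text_encoder.lora_A.weight": 2,  # wrong prefix, filtered out by both forms
}

sliced = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
stripped = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
assert sliced == stripped == {"blocks.0.lora_A.weight": 1}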
@@ -269,13 +244,6 @@
                     "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
                 )
 
-            # Control LoRA from SAI is different from BFL Control LoRA
-            # https://huggingface.co/stabilityai/control-lora/
-            is_control_lora = "lora_controlnet" in state_dict
-            if is_control_lora:
-                del state_dict["lora_controlnet"]
-                state_dict = convert_control_lora_state_dict_to_peft(state_dict)
-
             # check with first key if is not in peft format
             first_key = next(iter(state_dict.keys()))
             if "lora_A" not in first_key:
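Context for the deleted branch: per the code being removed, SAI control-lora checkpoints are recognized by a sentinel "lora_controlnet" entry, which is deleted before the remaining keys are converted to the PEFT layout. A toy sketch (the non-sentinel key name and value are invented):

state_dict = {
    "lora_controlnet": None,  # sentinel entry, not a usable weight
    "input_blocks.0.lora.down.weight": 0.0,
}

is_control_lora = "lora_controlnet" in state_dict
if is_control_lora:
    # drop the marker so only real weights reach the PEFT conversion
    del state_dict["lora_controlnet"]
assert list(state_dict) == ["input_blocks.0.lora.down.weight"]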
@@ -294,12 +262,12 @@
 
             if network_alphas is not None and len(network_alphas) >= 1:
                 alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
-                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+                network_alphas = {
+                    k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
+                }
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)
-            if is_control_lora:
-                lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs)
 
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
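Unlike the state_dict hunk above, this change is not purely cosmetic: str.replace removes every occurrence of the prefix substring, while removeprefix touches only the leading one. A hypothetical key shows the divergence:

prefix = "unet"
key = "unet.up.unet.proj.alpha"

assert key.replace(f"{prefix}.", "") == "up.proj.alpha"        # old form: both hits stripped
assert key.removeprefix(f"{prefix}.") == "up.unet.proj.alpha"  # new form: leading hit only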
@@ -365,7 +333,7 @@
                     new_sd[k] = v
                 return new_sd
 
-        # To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
+        # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
         # we should also delete the `peft_config` associated to the `adapter_name`.
         try:
             if hotswap:
@@ -379,7 +347,7 @@
                         config=lora_config,
                     )
                 except Exception as e:
-                    logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
+                    logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
                     raise
                 # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
                 # it to None
@@ -414,7 +382,7 @@
                         module.delete_adapter(adapter_name)
 
                 self.peft_config.pop(adapter_name)
-                logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
+                logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
                 raise
 
         warn_msg = ""
@@ -747,7 +715,7 @@
                 if self.lora_scale != 1.0:
                     module.scale_layer(self.lora_scale)
 
-                # For BC with prevous PEFT versions, we need to check the signature
+                # For BC with previous PEFT versions, we need to check the signature
                 # of the `merge` method to see if it supports the `adapter_names` argument.
                 supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
                 if "adapter_names" in supported_merge_kwargs:
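The signature probe in this final hunk is a general backward-compatibility pattern: list the callee's parameters and pass a keyword argument only when it is accepted. A self-contained sketch, with merge_v1/merge_v2 as stand-ins for older and newer PEFT layer APIs:

import inspect

def merge_v1(safe_merge=False):
    # older API: no adapter_names support
    return "merged all adapters"

def merge_v2(safe_merge=False, adapter_names=None):
    # newer API: can merge a chosen subset of adapters
    return f"merged {adapter_names or 'all adapters'}"

for merge in (merge_v1, merge_v2):
    supported_merge_kwargs = list(inspect.signature(merge).parameters)
    kwargs = {"adapter_names": ["default"]} if "adapter_names" in supported_merge_kwargs else {}
    print(merge(**kwargs))  # "merged all adapters", then "merged ['default']"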