1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

change peft.py

This commit is contained in:
lavinal712
2025-05-29 14:23:41 +00:00
parent 6fff794e59
commit 63bafc88cd

View File

@@ -25,7 +25,6 @@ from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
convert_control_lora_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
@@ -53,9 +52,12 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
}
@@ -79,33 +81,6 @@ def _maybe_raise_error_for_ambiguity(config):
)
def _maybe_adjust_config_for_control_lora(config):
"""
"""
target_modules_before = config["target_modules"]
target_modules = []
modules_to_save = []
for module in target_modules_before:
if module.endswith("weight"):
base_name = ".".join(module.split(".")[:-1])
modules_to_save.append(base_name)
elif module.endswith("bias"):
base_name = ".".join(module.split(".")[:-1])
if ".".join([base_name, "weight"]) in target_modules_before:
modules_to_save.append(base_name)
else:
target_modules.append(base_name)
else:
target_modules.append(module)
config["target_modules"] = list(set(target_modules))
config["modules_to_save"] = list(set(modules_to_save))
return config
class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -256,7 +231,7 @@ class PeftAdapterMixin:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
if prefix is not None:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -269,13 +244,6 @@ class PeftAdapterMixin:
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
)
# Control LoRA from SAI is different from BFL Control LoRA
# https://huggingface.co/stabilityai/control-lora/
is_control_lora = "lora_controlnet" in state_dict
if is_control_lora:
del state_dict["lora_controlnet"]
state_dict = convert_control_lora_state_dict_to_peft(state_dict)
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
@@ -294,12 +262,12 @@ class PeftAdapterMixin:
if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
network_alphas = {
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
if is_control_lora:
lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
@@ -365,7 +333,7 @@ class PeftAdapterMixin:
new_sd[k] = v
return new_sd
# To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
# To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
# we should also delete the `peft_config` associated to the `adapter_name`.
try:
if hotswap:
@@ -379,7 +347,7 @@ class PeftAdapterMixin:
config=lora_config,
)
except Exception as e:
logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
raise
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set
# it to None
@@ -414,7 +382,7 @@ class PeftAdapterMixin:
module.delete_adapter(adapter_name)
self.peft_config.pop(adapter_name)
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
raise
warn_msg = ""
@@ -747,7 +715,7 @@ class PeftAdapterMixin:
if self.lora_scale != 1.0:
module.scale_layer(self.lora_scale)
# For BC with prevous PEFT versions, we need to check the signature
# For BC with previous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs: