1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

merged but bug

This commit is contained in:
lavinal712
2025-04-09 07:56:40 +00:00
parent ab9eeff757
commit 6fff794e59
2 changed files with 38 additions and 0 deletions

View File

@@ -25,6 +25,7 @@ from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
convert_control_lora_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
@@ -78,6 +79,33 @@ def _maybe_raise_error_for_ambiguity(config):
)
def _maybe_adjust_config_for_control_lora(config):
"""
"""
target_modules_before = config["target_modules"]
target_modules = []
modules_to_save = []
for module in target_modules_before:
if module.endswith("weight"):
base_name = ".".join(module.split(".")[:-1])
modules_to_save.append(base_name)
elif module.endswith("bias"):
base_name = ".".join(module.split(".")[:-1])
if ".".join([base_name, "weight"]) in target_modules_before:
modules_to_save.append(base_name)
else:
target_modules.append(base_name)
else:
target_modules.append(module)
config["target_modules"] = list(set(target_modules))
config["modules_to_save"] = list(set(modules_to_save))
return config
class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -241,6 +269,13 @@ class PeftAdapterMixin:
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
)
# Control LoRA from SAI is different from BFL Control LoRA
# https://huggingface.co/stabilityai/control-lora/
is_control_lora = "lora_controlnet" in state_dict
if is_control_lora:
del state_dict["lora_controlnet"]
state_dict = convert_control_lora_state_dict_to_peft(state_dict)
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
@@ -263,6 +298,8 @@ class PeftAdapterMixin:
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
if is_control_lora:
lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:

View File

@@ -122,6 +122,7 @@ from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_p
from .remote_utils import remote_decode
from .state_dict_utils import (
convert_all_state_dict_to_peft,
convert_control_lora_state_dict_to_peft,
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_state_dict_to_peft,