merged but bug
@@ -25,6 +25,7 @@ from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
     check_peft_version,
+    convert_control_lora_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
@@ -78,6 +79,33 @@ def _maybe_raise_error_for_ambiguity(config):
     )
 
 
+def _maybe_adjust_config_for_control_lora(config):
+    """Move entries ending in `weight`/`bias` from `target_modules` to `modules_to_save`, since SAI Control LoRA
+    checkpoints ship those parameters fully trained rather than as low-rank updates."""
+    target_modules_before = config["target_modules"]
+    target_modules = []
+    modules_to_save = []
+
+    for module in target_modules_before:
+        if module.endswith("weight"):
+            base_name = ".".join(module.split(".")[:-1])
+            modules_to_save.append(base_name)
+        elif module.endswith("bias"):
+            base_name = ".".join(module.split(".")[:-1])
+            if ".".join([base_name, "weight"]) in target_modules_before:
+                modules_to_save.append(base_name)
+            else:
+                target_modules.append(base_name)
+        else:
+            target_modules.append(module)
+
+    config["target_modules"] = list(set(target_modules))
+    config["modules_to_save"] = list(set(modules_to_save))
+
+    return config
+
+
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
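The hunk above rewrites the PEFT config for SAI Control LoRA checkpoints: entries in `target_modules` that name a full `weight` or `bias` tensor are collapsed to their parent module and moved into `modules_to_save`, so PEFT trains and loads them in full instead of injecting low-rank matrices. A minimal sketch of the effect, with made-up module names, assuming the helper lives in `diffusers.loaders.peft` as the hunk suggests:

from diffusers.loaders.peft import _maybe_adjust_config_for_control_lora

# Hypothetical module names; real keys come out of get_peft_kwargs().
config = {
    "target_modules": [
        "input_blocks.0.0.weight",  # full tensor shipped -> parent moves to modules_to_save
        "input_blocks.0.0.bias",    # its matching weight is also present -> modules_to_save
        "middle_block.1.proj_in",   # plain module name -> stays a LoRA target
    ]
}
adjusted = _maybe_adjust_config_for_control_lora(config)
print(adjusted["target_modules"])   # ['middle_block.1.proj_in']
print(adjusted["modules_to_save"])  # ['input_blocks.0.0']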
@@ -241,6 +269,13 @@ class PeftAdapterMixin:
                     "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
                 )
 
+            # Control LoRA from SAI is different from BFL Control LoRA
+            # https://huggingface.co/stabilityai/control-lora/
+            is_control_lora = "lora_controlnet" in state_dict
+            if is_control_lora:
+                del state_dict["lora_controlnet"]
+                state_dict = convert_control_lora_state_dict_to_peft(state_dict)
+
             # use the first key to check whether the state dict is in PEFT format
             first_key = next(iter(state_dict.keys()))
             if "lora_A" not in first_key:
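The detection added here keys off a sentinel entry: SAI Control LoRA files carry a `lora_controlnet` key that PEFT-format checkpoints lack. A hedged sketch of the branch with dummy tensors (a real checkpoint would also carry the actual adapter weights):

import torch

# Dummy state dict for illustration: only the marker key and one fake tensor.
state_dict = {
    "lora_controlnet": torch.tensor([]),          # sentinel present only in SAI checkpoints
    "input_blocks.0.0.weight": torch.zeros(4, 4),
}

is_control_lora = "lora_controlnet" in state_dict  # True for this file
if is_control_lora:
    del state_dict["lora_controlnet"]  # drop the marker before key conversion
    # A real checkpoint would now be passed through
    # convert_control_lora_state_dict_to_peft(state_dict) to get lora_A/lora_B keys.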
@@ -263,6 +298,8 @@
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)
+            if is_control_lora:
+                lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs)
 
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
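Downstream, `lora_config_kwargs` feeds the PEFT `LoraConfig`, so the adjustment above lands as `modules_to_save`, which PEFT keeps fully trainable and serializes whole rather than wrapping with lora_A/lora_B matrices. A sketch with invented values:

from peft import LoraConfig

# Invented rank/targets for illustration; real values come from get_peft_kwargs().
lora_config_kwargs = {
    "r": 4,
    "lora_alpha": 4,
    "target_modules": ["middle_block.1.proj_in"],  # gets lora_A/lora_B injected
    "modules_to_save": ["input_blocks.0.0"],       # kept fully trainable, saved whole
}
lora_config = LoraConfig(**lora_config_kwargs)
print(lora_config.modules_to_save)  # ['input_blocks.0.0']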
@@ -122,6 +122,7 @@ from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_p
 from .remote_utils import remote_decode
 from .state_dict_utils import (
     convert_all_state_dict_to_peft,
+    convert_control_lora_state_dict_to_peft,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
     convert_state_dict_to_peft,
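This last hunk appears to edit `diffusers/utils/__init__.py`, re-exporting the new converter at the utils level, so callers can import it directly:

# Assuming the hunk above targets diffusers/utils/__init__.py, this import now works.
from diffusers.utils import convert_control_lora_state_dict_to_peft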