From 6fff794e59a4d09bb5d6848eaebe297cbcb3c0ec Mon Sep 17 00:00:00 2001
From: lavinal712
Date: Wed, 9 Apr 2025 07:56:40 +0000
Subject: [PATCH] merged but bug

---
 src/diffusers/loaders/peft.py   | 37 +++++++++++++++++++++++++++++++++
 src/diffusers/utils/__init__.py |  1 +
 2 files changed, 38 insertions(+)

diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 9165c46f3c..0280fc23f7 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -25,6 +25,7 @@ from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
     check_peft_version,
+    convert_control_lora_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
@@ -78,6 +79,33 @@ def _maybe_raise_error_for_ambiguity(config):
                 )
 
 
+def _maybe_adjust_config_for_control_lora(config):
+    """Moves `target_modules` entries that refer to full `weight`/`bias` tensors in a Control LoRA state dict
+    into `modules_to_save`, so those modules are loaded as full tensors instead of being wrapped in LoRA layers."""
+
+    target_modules_before = config["target_modules"]
+    target_modules = []
+    modules_to_save = []
+
+    for module in target_modules_before:
+        if module.endswith("weight"):
+            base_name = ".".join(module.split(".")[:-1])
+            modules_to_save.append(base_name)
+        elif module.endswith("bias"):
+            base_name = ".".join(module.split(".")[:-1])
+            if ".".join([base_name, "weight"]) in target_modules_before:
+                modules_to_save.append(base_name)
+            else:
+                target_modules.append(base_name)
+        else:
+            target_modules.append(module)
+
+    config["target_modules"] = list(set(target_modules))
+    config["modules_to_save"] = list(set(modules_to_save))
+
+    return config
+
+
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -241,6 +269,13 @@ class PeftAdapterMixin:
                     "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
                 )
 
+            # Control LoRA from SAI (Stability AI) uses a different format than BFL (Black Forest Labs) Control LoRA
+            # https://huggingface.co/stabilityai/control-lora/
+            is_control_lora = "lora_controlnet" in state_dict
+            if is_control_lora:
+                del state_dict["lora_controlnet"]
+                state_dict = convert_control_lora_state_dict_to_peft(state_dict)
+
             # check with first key if is not in peft format
             first_key = next(iter(state_dict.keys()))
             if "lora_A" not in first_key:
@@ -263,6 +298,8 @@ class PeftAdapterMixin:
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)
+            if is_control_lora:
+                lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs)
 
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 438faa23e5..777cfec714 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -122,6 +122,7 @@ from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_p
 from .remote_utils import remote_decode
 from .state_dict_utils import (
     convert_all_state_dict_to_peft,
+    convert_control_lora_state_dict_to_peft,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
     convert_state_dict_to_peft,
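
Not part of the patch: a standalone sketch of what the new _maybe_adjust_config_for_control_lora helper does to a PEFT LoRA config. The helper body mirrors the hunk above; the module names in the example are made up for illustration and are not taken from the stabilityai/control-lora checkpoint.

# Standalone illustration of _maybe_adjust_config_for_control_lora (same logic as the patch).
def _maybe_adjust_config_for_control_lora(config):
    target_modules_before = config["target_modules"]
    target_modules = []
    modules_to_save = []

    for module in target_modules_before:
        if module.endswith("weight"):
            # A full weight tensor is stored in the checkpoint -> save the whole module.
            modules_to_save.append(".".join(module.split(".")[:-1]))
        elif module.endswith("bias"):
            base_name = ".".join(module.split(".")[:-1])
            if ".".join([base_name, "weight"]) in target_modules_before:
                # Bias of a fully stored module -> already covered by modules_to_save.
                modules_to_save.append(base_name)
            else:
                # Lone bias -> keep the base module as a LoRA target.
                target_modules.append(base_name)
        else:
            # Ordinary LoRA target (A/B factors) -> keep as-is.
            target_modules.append(module)

    config["target_modules"] = list(set(target_modules))
    config["modules_to_save"] = list(set(modules_to_save))
    return config


# Illustrative module names only (hypothetical, not from the real checkpoint).
config = {
    "target_modules": [
        "controlnet_cond_embedding.conv_in.weight",   # full tensor -> modules_to_save
        "controlnet_cond_embedding.conv_in.bias",     # bias of the same module -> modules_to_save
        "time_embedding.linear_1.bias",               # lone bias -> stays a LoRA target (base name)
        "down_blocks.0.attentions.0.proj_in",         # ordinary LoRA target -> unchanged
    ]
}
adjusted = _maybe_adjust_config_for_control_lora(config)
print(sorted(adjusted["modules_to_save"]))  # ['controlnet_cond_embedding.conv_in']
print(sorted(adjusted["target_modules"]))   # ['down_blocks.0.attentions.0.proj_in', 'time_embedding.linear_1']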