diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 81d9b00fa6..723ad9707d 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -26,7 +26,7 @@ from ..utils import ( MIN_PEFT_VERSION, USE_PEFT_BACKEND, check_peft_version, - convert_control_lora_state_dict_to_peft, + convert_sai_sd_control_lora_state_dict_to_peft, convert_unet_state_dict_to_peft, delete_adapter_layers, get_adapter_name, @@ -233,9 +233,9 @@ class PeftAdapterMixin: # Control LoRA from SAI is different from BFL Control LoRA # https://huggingface.co/stabilityai/control-lora # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors - is_control_lora = "lora_controlnet" in state_dict - if is_control_lora: - state_dict = convert_control_lora_state_dict_to_peft(state_dict) + is_sai_sd_control_lora = "lora_controlnet" in state_dict + if is_sai_sd_control_lora: + state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict) rank = {} for key, val in state_dict.items(): @@ -269,7 +269,7 @@ class PeftAdapterMixin: ) # Adjust LoRA config for Control LoRA - if is_control_lora: + if is_sai_sd_control_lora: lora_config.lora_alpha = lora_config.r lora_config.alpha_pattern = lora_config.rank_pattern lora_config.bias = "all" diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 3f19e3e8eb..440e4539e7 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -143,7 +143,7 @@ from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_p from .remote_utils import remote_decode from .state_dict_utils import ( convert_all_state_dict_to_peft, - convert_control_lora_state_dict_to_peft, + convert_sai_sd_control_lora_state_dict_to_peft, convert_state_dict_to_diffusers, convert_state_dict_to_kohya, convert_state_dict_to_peft, diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 14d5e9cfb6..8c3aa5807e 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -289,7 +289,7 @@ def convert_unet_state_dict_to_peft(state_dict): return convert_state_dict(state_dict, mapping) -def convert_control_lora_state_dict_to_peft(state_dict): +def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): def _convert_controlnet_to_diffusers(state_dict): is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})")