1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
lavinal712
2025-11-22 09:50:45 +08:00
parent dfad05625e
commit 4d1e8912d6
3 changed files with 7 additions and 7 deletions

View File

@@ -26,7 +26,7 @@ from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
convert_control_lora_state_dict_to_peft,
convert_sai_sd_control_lora_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
@@ -233,9 +233,9 @@ class PeftAdapterMixin:
# Control LoRA from SAI is different from BFL Control LoRA
# https://huggingface.co/stabilityai/control-lora
# https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
is_control_lora = "lora_controlnet" in state_dict
if is_control_lora:
state_dict = convert_control_lora_state_dict_to_peft(state_dict)
is_sai_sd_control_lora = "lora_controlnet" in state_dict
if is_sai_sd_control_lora:
state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
rank = {}
for key, val in state_dict.items():
@@ -269,7 +269,7 @@ class PeftAdapterMixin:
)
# Adjust LoRA config for Control LoRA
if is_control_lora:
if is_sai_sd_control_lora:
lora_config.lora_alpha = lora_config.r
lora_config.alpha_pattern = lora_config.rank_pattern
lora_config.bias = "all"

View File

@@ -143,7 +143,7 @@ from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_p
from .remote_utils import remote_decode
from .state_dict_utils import (
convert_all_state_dict_to_peft,
convert_control_lora_state_dict_to_peft,
convert_sai_sd_control_lora_state_dict_to_peft,
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_state_dict_to_peft,

View File

@@ -289,7 +289,7 @@ def convert_unet_state_dict_to_peft(state_dict):
return convert_state_dict(state_dict, mapping)
def convert_control_lora_state_dict_to_peft(state_dict):
def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
def _convert_controlnet_to_diffusers(state_dict):
is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict
logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})")