
Support for control-lora (#10686)

* run control-lora on diffusers

* cannot load lora adapter

* test

* 1

* add control-lora

* 1

* 1

* 1

* fix PeftAdapterMixin

* fix module_to_save bug

* delete json print

* resolve conflicts

* merged but bug

* change peft.py

* 1

* delete state_dict print

* fix alpha

* Create control_lora.py

* Add files via upload

* rename

* no need modify as peft updated

* add doc

* fix code style

* styling isn't that hard 😉

* empty

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Yuqian Hong
2025-12-15 18:22:42 +08:00
committed by GitHub
parent 0c1ccc0775
commit 58519283e7
7 changed files with 312 additions and 1 deletion

View File

@@ -33,6 +33,21 @@ url = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/m
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
```
## Loading from Control LoRA
Control-LoRA was introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora); it applies low-rank, parameter-efficient fine-tuning to ControlNet, offering a more efficient and compact way to bring model control to a wider range of consumer GPUs.
```py
import torch

from diffusers import ControlNetModel, UNet2DConditionModel

lora_id = "stabilityai/control-lora"
lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"

# Build a ControlNet with the SDXL UNet's architecture, then load the Control-LoRA weights into it.
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.bfloat16).to("cuda")
controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16)
controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)
```
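Once the adapter is loaded, the ControlNet can be dropped into `StableDiffusionXLControlNetPipeline` like any other ControlNet. Below is a minimal inference sketch continuing from the snippet above; the prompt, reference image, and canny thresholds are illustrative and mirror the example script added in this PR.
```py
import cv2
import numpy as np
import torch
from PIL import Image

from diffusers import AutoencoderKL, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image

# Reuse `unet` and `controlnet` from the snippet above.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    controlnet=controlnet,
    vae=AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
).to("cuda")

# Turn a reference image into a three-channel canny edge map to use as the control signal.
reference = np.array(
    load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png")
)
edges = cv2.Canny(reference, 100, 200)
control_image = Image.fromarray(np.stack([edges] * 3, axis=-1))

image = pipe(
    "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting",
    negative_prompt="low quality, bad quality, sketches",
    image=control_image,
    controlnet_conditioning_scale=1.0,
).images[0]
```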
## ControlNetModel
[[autodoc]] ControlNetModel

View File

@@ -0,0 +1,41 @@
# Control-LoRA inference example
Control-LoRA was introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora); it applies low-rank, parameter-efficient fine-tuning to ControlNet, offering a more efficient and compact way to bring model control to a wider range of consumer GPUs.
## Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, since we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd into the example folder and run:
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
## Inference on SDXL
[stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) provides a set of Control-LoRA weights for SDXL. Here we use the `canny` condition to generate an image from a text prompt and a reference image.
```bash
python control_lora.py
```
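At a high level, `control_lora.py` builds a ControlNet from the SDXL UNet, loads the canny Control-LoRA weights into it, and runs the usual SDXL ControlNet pipeline with a canny edge map as the conditioning image. A condensed sketch of those core calls (see `control_lora.py` below for the complete script):
```py
import torch

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, UNet2DConditionModel

# Build a ControlNet matching the SDXL UNet and load the canny Control-LoRA weights into it.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.bfloat16
).to("cuda")
controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16)
controlnet.load_lora_adapter(
    "stabilityai/control-lora",
    weight_name="control-LoRAs-rank128/control-lora-canny-rank128.safetensors",
    prefix=None,
    controlnet_config=controlnet.config,
)

# The pipeline is then used like any other SDXL ControlNet pipeline, passing a canny edge map as `image=`.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")
```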
## Acknowledgements
- [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora)
- [comfyanonymous/ControlNet-v1-1_fp16_safetensors](https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors)
- [HighCWu/control-lora-v2](https://github.com/HighCWu/control-lora-v2)

View File

@@ -0,0 +1,58 @@
import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import (
AutoencoderKL,
ControlNetModel,
StableDiffusionXLControlNetPipeline,
UNet2DConditionModel,
)
from diffusers.utils import load_image, make_image_grid
pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
lora_id = "stabilityai/control-lora"
lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"
# Create a ControlNet with the SDXL UNet's architecture and load the canny Control-LoRA weights into it.
unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.bfloat16).to("cuda")
controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16)
controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)
prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
negative_prompt = "low quality, bad quality, sketches"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
controlnet_conditioning_scale = 1.0 # recommended for good generalization
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.bfloat16)
# Assemble the SDXL ControlNet pipeline around the adapted ControlNet (SDXL pipelines have no safety_checker component).
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
pipe_id,
unet=unet,
controlnet=controlnet,
vae=vae,
torch_dtype=torch.bfloat16,
).to("cuda")
# Convert the reference image into a canny edge map and replicate it to three channels.
image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
image = Image.fromarray(image)
images = pipe(
prompt,
negative_prompt=negative_prompt,
image=image,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_images_per_prompt=4,
).images
# Save the control image alongside the four generated samples in a 1x5 grid.
final_image = [image] + images
grid = make_image_grid(final_image, 1, 5)
grid.save("hf-logo_canny.png")

View File

@@ -27,6 +27,7 @@ from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
convert_sai_sd_control_lora_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
@@ -232,6 +233,13 @@ class PeftAdapterMixin:
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
# Control LoRA from SAI is different from BFL Control LoRA
# https://huggingface.co/stabilityai/control-lora
# https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
is_sai_sd_control_lora = "lora_controlnet" in state_dict
if is_sai_sd_control_lora:
state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
rank = {}
for key, val in state_dict.items():
# Cannot figure out rank from lora layers that don't have at least 2 dimensions.
@@ -263,6 +271,14 @@ class PeftAdapterMixin:
adapter_name=adapter_name,
)
# Adjust LoRA config for Control LoRA
if is_sai_sd_control_lora:
lora_config.lora_alpha = lora_config.r
lora_config.alpha_pattern = lora_config.rank_pattern
lora_config.bias = "all"
lora_config.modules_to_save = lora_config.exclude_modules
lora_config.exclude_modules = None
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to `_pipeline`.
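For readers following the loader changes above: SAI Control-LoRA checkpoints are identified by a marker key, and the resulting PEFT config is adjusted because these checkpoints also store some layers as full weights, which is why the modules previously excluded from LoRA are moved to `modules_to_save`. A rough standalone sketch of that detection and adjustment (illustrative only, not the actual `PeftAdapterMixin` code path):
```py
from peft import LoraConfig


def is_sai_control_lora(state_dict) -> bool:
    # SAI checkpoints ship a marker entry named "lora_controlnet"; BFL-style control LoRAs do not.
    return "lora_controlnet" in state_dict


def adjust_lora_config(config: LoraConfig) -> LoraConfig:
    # Mirror the adjustments made in the loader: tie alpha to rank, train all biases,
    # and keep the layers stored as full weights via `modules_to_save`.
    config.lora_alpha = config.r
    config.alpha_pattern = config.rank_pattern
    config.bias = "all"
    config.modules_to_save = config.exclude_modules
    config.exclude_modules = None
    return config
```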

View File

@@ -19,6 +19,7 @@ from torch import nn
from torch.nn import functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import BaseOutput, logging
from ..attention import AttentionMixin
@@ -106,7 +107,7 @@ class ControlNetConditioningEmbedding(nn.Module):
return embedding
class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin):
class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
"""
A ControlNet model.

View File

@@ -143,6 +143,7 @@ from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_p
from .remote_utils import remote_decode
from .state_dict_utils import (
convert_all_state_dict_to_peft,
convert_sai_sd_control_lora_state_dict_to_peft,
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_state_dict_to_peft,

View File

@@ -56,6 +56,36 @@ UNET_TO_DIFFUSERS = {
".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
}
CONTROL_LORA_TO_DIFFUSERS = {
".to_q.down": ".to_q.lora_A.weight",
".to_q.up": ".to_q.lora_B.weight",
".to_k.down": ".to_k.lora_A.weight",
".to_k.up": ".to_k.lora_B.weight",
".to_v.down": ".to_v.lora_A.weight",
".to_v.up": ".to_v.lora_B.weight",
".to_out.0.down": ".to_out.0.lora_A.weight",
".to_out.0.up": ".to_out.0.lora_B.weight",
".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight",
".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight",
".ff.net.2.down": ".ff.net.2.lora_A.weight",
".ff.net.2.up": ".ff.net.2.lora_B.weight",
".proj_in.down": ".proj_in.lora_A.weight",
".proj_in.up": ".proj_in.lora_B.weight",
".proj_out.down": ".proj_out.lora_A.weight",
".proj_out.up": ".proj_out.lora_B.weight",
".conv.down": ".conv.lora_A.weight",
".conv.up": ".conv.lora_B.weight",
**{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)},
**{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)},
"conv_in.down": "conv_in.lora_A.weight",
"conv_in.up": "conv_in.lora_B.weight",
".conv_shortcut.down": ".conv_shortcut.lora_A.weight",
".conv_shortcut.up": ".conv_shortcut.lora_B.weight",
**{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)},
**{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)},
"time_emb_proj.down": "time_emb_proj.lora_A.weight",
"time_emb_proj.up": "time_emb_proj.lora_B.weight",
}
DIFFUSERS_TO_PEFT = {
".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
@@ -259,6 +289,155 @@ def convert_unet_state_dict_to_peft(state_dict):
return convert_state_dict(state_dict, mapping)
def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
def _convert_controlnet_to_diffusers(state_dict):
is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict
logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})")
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
layers_per_block = 2
# op blocks
op_blocks = [key for key in state_dict if "0.op" in key]
converted_state_dict = {}
# Conv in layers
for key in input_blocks[0]:
diffusers_key = key.replace("input_blocks.0.0", "conv_in")
converted_state_dict[diffusers_key] = state_dict.get(key)
# controlnet time embedding blocks
time_embedding_blocks = [key for key in state_dict if "time_embed" in key]
for key in time_embedding_blocks:
diffusers_key = key.replace("time_embed.0", "time_embedding.linear_1").replace(
"time_embed.2", "time_embedding.linear_2"
)
converted_state_dict[diffusers_key] = state_dict.get(key)
# controlnet label embedding blocks
label_embedding_blocks = [key for key in state_dict if "label_emb" in key]
for key in label_embedding_blocks:
diffusers_key = key.replace("label_emb.0.0", "add_embedding.linear_1").replace(
"label_emb.0.2", "add_embedding.linear_2"
)
converted_state_dict[diffusers_key] = state_dict.get(key)
# Down blocks
for i in range(1, num_input_blocks):
block_id = (i - 1) // (layers_per_block + 1)
layer_in_block_id = (i - 1) % (layers_per_block + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
for key in resnets:
diffusers_key = (
key.replace("in_layers.0", "norm1")
.replace("in_layers.2", "conv1")
.replace("out_layers.0", "norm2")
.replace("out_layers.3", "conv2")
.replace("emb_layers.1", "time_emb_proj")
.replace("skip_connection", "conv_shortcut")
)
diffusers_key = diffusers_key.replace(
f"input_blocks.{i}.0", f"down_blocks.{block_id}.resnets.{layer_in_block_id}"
)
converted_state_dict[diffusers_key] = state_dict.get(key)
if f"input_blocks.{i}.0.op.bias" in state_dict:
for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]:
diffusers_key = key.replace(
f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv"
)
converted_state_dict[diffusers_key] = state_dict.get(key)
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if attentions:
for key in attentions:
diffusers_key = key.replace(
f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
)
converted_state_dict[diffusers_key] = state_dict.get(key)
# controlnet down blocks
for i in range(num_input_blocks):
converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight")
converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias")
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Mid blocks
for key in middle_blocks.keys():
diffusers_key = max(key - 1, 0)
if key % 2 == 0:
for k in middle_blocks[key]:
diffusers_key_hf = (
k.replace("in_layers.0", "norm1")
.replace("in_layers.2", "conv1")
.replace("out_layers.0", "norm2")
.replace("out_layers.3", "conv2")
.replace("emb_layers.1", "time_emb_proj")
.replace("skip_connection", "conv_shortcut")
)
diffusers_key_hf = diffusers_key_hf.replace(
f"middle_block.{key}", f"mid_block.resnets.{diffusers_key}"
)
converted_state_dict[diffusers_key_hf] = state_dict.get(k)
else:
for k in middle_blocks[key]:
diffusers_key_hf = k.replace(f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}")
converted_state_dict[diffusers_key_hf] = state_dict.get(k)
# mid block
converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight")
converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias")
# controlnet cond embedding blocks
cond_embedding_blocks = {
".".join(layer.split(".")[:2])
for layer in state_dict
if "input_hint_block" in layer
and ("input_hint_block.0" not in layer)
and ("input_hint_block.14" not in layer)
}
num_cond_embedding_blocks = len(cond_embedding_blocks)
for idx in range(1, num_cond_embedding_blocks + 1):
diffusers_idx = idx - 1
cond_block_id = 2 * idx
converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = state_dict.get(
f"input_hint_block.{cond_block_id}.weight"
)
converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get(
f"input_hint_block.{cond_block_id}.bias"
)
for key in [key for key in state_dict if "input_hint_block.0" in key]:
diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in")
converted_state_dict[diffusers_key] = state_dict.get(key)
for key in [key for key in state_dict if "input_hint_block.14" in key]:
diffusers_key = key.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out")
converted_state_dict[diffusers_key] = state_dict.get(key)
return converted_state_dict
state_dict = _convert_controlnet_to_diffusers(state_dict)
mapping = CONTROL_LORA_TO_DIFFUSERS
return convert_state_dict(state_dict, mapping)
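To make the two-stage conversion above concrete: the checkpoint keys are first renamed from the original ControlNet layout to the diffusers layout, and then `convert_state_dict` rewrites the `.down`/`.up` LoRA suffixes to PEFT's `lora_A`/`lora_B` naming using `CONTROL_LORA_TO_DIFFUSERS`. A toy illustration of that second step (plain substring replacement, not the library's implementation; the key is just an example):
```py
# Toy key-renaming pass in the spirit of convert_state_dict + CONTROL_LORA_TO_DIFFUSERS.
mapping = {
    ".to_q.down": ".to_q.lora_A.weight",
    ".to_q.up": ".to_q.lora_B.weight",
}


def rename_key(key: str) -> str:
    for old, new in mapping.items():
        if old in key:
            return key.replace(old, new)
    return key


print(rename_key("down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.down"))
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight
```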
def convert_all_state_dict_to_peft(state_dict):
r"""
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid