mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
up
This commit is contained in:
@@ -1,16 +1,17 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import (
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
|
||||
pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
lora_id = "stabilityai/control-lora"
|
||||
lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"
|
||||
@@ -22,7 +23,9 @@ controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, co
|
||||
prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
negative_prompt = "low quality, bad quality, sketches"
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png")
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
|
||||
)
|
||||
|
||||
controlnet_conditioning_scale = 1.0 # recommended for good generalization
|
||||
|
||||
@@ -43,9 +46,11 @@ image = np.concatenate([image, image, image], axis=2)
|
||||
image = Image.fromarray(image)
|
||||
|
||||
images = pipe(
|
||||
prompt, negative_prompt=negative_prompt, image=image,
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=image,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
num_images_per_prompt=4
|
||||
num_images_per_prompt=4,
|
||||
).images
|
||||
|
||||
final_image = [image] + images
|
||||
|
||||
@@ -293,11 +293,9 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
|
||||
def _convert_controlnet_to_diffusers(state_dict):
|
||||
is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict
|
||||
logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})")
|
||||
|
||||
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len(
|
||||
{".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}
|
||||
)
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer})
|
||||
input_blocks = {
|
||||
layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_input_blocks)
|
||||
@@ -312,20 +310,20 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
|
||||
for key in input_blocks[0]:
|
||||
diffusers_key = key.replace("input_blocks.0.0", "conv_in")
|
||||
converted_state_dict[diffusers_key] = state_dict.get(key)
|
||||
|
||||
|
||||
# controlnet time embedding blocks
|
||||
time_embedding_blocks = [key for key in state_dict if "time_embed" in key]
|
||||
for key in time_embedding_blocks:
|
||||
diffusers_key = (key.replace("time_embed.0", "time_embedding.linear_1")
|
||||
.replace("time_embed.2", "time_embedding.linear_2")
|
||||
diffusers_key = key.replace("time_embed.0", "time_embedding.linear_1").replace(
|
||||
"time_embed.2", "time_embedding.linear_2"
|
||||
)
|
||||
converted_state_dict[diffusers_key] = state_dict.get(key)
|
||||
|
||||
# controlnet label embedding blocks
|
||||
label_embedding_blocks = [key for key in state_dict if "label_emb" in key]
|
||||
for key in label_embedding_blocks:
|
||||
diffusers_key = (key.replace("label_emb.0.0", "add_embedding.linear_1")
|
||||
.replace("label_emb.0.2", "add_embedding.linear_2")
|
||||
diffusers_key = key.replace("label_emb.0.0", "add_embedding.linear_1").replace(
|
||||
"label_emb.0.2", "add_embedding.linear_2"
|
||||
)
|
||||
converted_state_dict[diffusers_key] = state_dict.get(key)
|
||||
|
||||
@@ -338,7 +336,8 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
|
||||
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
||||
]
|
||||
for key in resnets:
|
||||
diffusers_key = (key.replace("in_layers.0", "norm1")
|
||||
diffusers_key = (
|
||||
key.replace("in_layers.0", "norm1")
|
||||
.replace("in_layers.2", "conv1")
|
||||
.replace("out_layers.0", "norm2")
|
||||
.replace("out_layers.3", "conv2")
|
||||
@@ -352,7 +351,9 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
|
||||
|
||||
if f"input_blocks.{i}.0.op.bias" in state_dict:
|
||||
for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]:
|
||||
diffusers_key = key.replace(f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv")
|
||||
diffusers_key = key.replace(
|
||||
f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv"
|
||||
)
|
||||
converted_state_dict[diffusers_key] = state_dict.get(key)
|
||||
|
||||
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||
@@ -362,16 +363,14 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
|
||||
f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
|
||||
)
|
||||
converted_state_dict[diffusers_key] = state_dict.get(key)
|
||||
|
||||
|
||||
# controlnet down blocks
|
||||
for i in range(num_input_blocks):
|
||||
converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight")
|
||||
converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias")
|
||||
|
||||
# Retrieves the keys for the middle blocks only
|
||||
num_middle_blocks = len(
|
||||
{".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer}
|
||||
)
|
||||
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer})
|
||||
middle_blocks = {
|
||||
layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key]
|
||||
for layer_id in range(num_middle_blocks)
|
||||
@@ -382,7 +381,8 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
|
||||
diffusers_key = max(key - 1, 0)
|
||||
if key % 2 == 0:
|
||||
for k in middle_blocks[key]:
|
||||
diffusers_key_hf = (k.replace("in_layers.0", "norm1")
|
||||
diffusers_key_hf = (
|
||||
k.replace("in_layers.0", "norm1")
|
||||
.replace("in_layers.2", "conv1")
|
||||
.replace("out_layers.0", "norm2")
|
||||
.replace("out_layers.3", "conv2")
|
||||
@@ -395,11 +395,9 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
|
||||
converted_state_dict[diffusers_key_hf] = state_dict.get(k)
|
||||
else:
|
||||
for k in middle_blocks[key]:
|
||||
diffusers_key_hf = k.replace(
|
||||
f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}"
|
||||
)
|
||||
diffusers_key_hf = k.replace(f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}")
|
||||
converted_state_dict[diffusers_key_hf] = state_dict.get(k)
|
||||
|
||||
|
||||
# mid block
|
||||
converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight")
|
||||
converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias")
|
||||
@@ -408,7 +406,9 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
|
||||
cond_embedding_blocks = {
|
||||
".".join(layer.split(".")[:2])
|
||||
for layer in state_dict
|
||||
if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
|
||||
if "input_hint_block" in layer
|
||||
and ("input_hint_block.0" not in layer)
|
||||
and ("input_hint_block.14" not in layer)
|
||||
}
|
||||
num_cond_embedding_blocks = len(cond_embedding_blocks)
|
||||
|
||||
@@ -422,13 +422,13 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
|
||||
converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get(
|
||||
f"input_hint_block.{cond_block_id}.bias"
|
||||
)
|
||||
|
||||
|
||||
for key in [key for key in state_dict if "input_hint_block.0" in key]:
|
||||
diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in")
|
||||
converted_state_dict[diffusers_key] = state_dict.get(key)
|
||||
|
||||
|
||||
for key in [key for key in state_dict if "input_hint_block.14" in key]:
|
||||
diffusers_key = key.replace(f"input_hint_block.14", "controlnet_cond_embedding.conv_out")
|
||||
diffusers_key = key.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out")
|
||||
converted_state_dict[diffusers_key] = state_dict.get(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
Reference in New Issue
Block a user