diff --git a/examples/research_projects/control_lora/control_lora.py b/examples/research_projects/control_lora/control_lora.py index 435c9c945b..a0ad1981c7 100644 --- a/examples/research_projects/control_lora/control_lora.py +++ b/examples/research_projects/control_lora/control_lora.py @@ -1,16 +1,17 @@ import cv2 import numpy as np -from PIL import Image import torch +from PIL import Image from diffusers import ( - StableDiffusionXLControlNetPipeline, + AutoencoderKL, ControlNetModel, + StableDiffusionXLControlNetPipeline, UNet2DConditionModel, ) -from diffusers import AutoencoderKL from diffusers.utils import load_image, make_image_grid + pipe_id = "stabilityai/stable-diffusion-xl-base-1.0" lora_id = "stabilityai/control-lora" lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors" @@ -22,7 +23,9 @@ controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, co prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" negative_prompt = "low quality, bad quality, sketches" -image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png") +image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" +) controlnet_conditioning_scale = 1.0 # recommended for good generalization @@ -43,9 +46,11 @@ image = np.concatenate([image, image, image], axis=2) image = Image.fromarray(image) images = pipe( - prompt, negative_prompt=negative_prompt, image=image, + prompt, + negative_prompt=negative_prompt, + image=image, controlnet_conditioning_scale=controlnet_conditioning_scale, - num_images_per_prompt=4 + num_images_per_prompt=4, ).images final_image = [image] + images diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 8c3aa5807e..c9bf5fec28 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -293,11 +293,9 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): def _convert_controlnet_to_diffusers(state_dict): is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})") - + # Retrieves the keys for the input blocks only - num_input_blocks = len( - {".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer} - ) + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}) input_blocks = { layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key] for layer_id in range(num_input_blocks) @@ -312,20 +310,20 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): for key in input_blocks[0]: diffusers_key = key.replace("input_blocks.0.0", "conv_in") converted_state_dict[diffusers_key] = state_dict.get(key) - + # controlnet time embedding blocks time_embedding_blocks = [key for key in state_dict if "time_embed" in key] for key in time_embedding_blocks: - diffusers_key = (key.replace("time_embed.0", "time_embedding.linear_1") - .replace("time_embed.2", "time_embedding.linear_2") + diffusers_key = key.replace("time_embed.0", "time_embedding.linear_1").replace( + "time_embed.2", "time_embedding.linear_2" ) converted_state_dict[diffusers_key] = state_dict.get(key) # controlnet label embedding blocks label_embedding_blocks = [key for key in state_dict if "label_emb" in key] for key in label_embedding_blocks: - diffusers_key = (key.replace("label_emb.0.0", "add_embedding.linear_1") - .replace("label_emb.0.2", "add_embedding.linear_2") + diffusers_key = key.replace("label_emb.0.0", "add_embedding.linear_1").replace( + "label_emb.0.2", "add_embedding.linear_2" ) converted_state_dict[diffusers_key] = state_dict.get(key) @@ -338,7 +336,8 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] for key in resnets: - diffusers_key = (key.replace("in_layers.0", "norm1") + diffusers_key = ( + key.replace("in_layers.0", "norm1") .replace("in_layers.2", "conv1") .replace("out_layers.0", "norm2") .replace("out_layers.3", "conv2") @@ -352,7 +351,9 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): if f"input_blocks.{i}.0.op.bias" in state_dict: for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]: - diffusers_key = key.replace(f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv") + diffusers_key = key.replace( + f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv" + ) converted_state_dict[diffusers_key] = state_dict.get(key) attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] @@ -362,16 +363,14 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}" ) converted_state_dict[diffusers_key] = state_dict.get(key) - + # controlnet down blocks for i in range(num_input_blocks): converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight") converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias") # Retrieves the keys for the middle blocks only - num_middle_blocks = len( - {".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer} - ) + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer}) middle_blocks = { layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key] for layer_id in range(num_middle_blocks) @@ -382,7 +381,8 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): diffusers_key = max(key - 1, 0) if key % 2 == 0: for k in middle_blocks[key]: - diffusers_key_hf = (k.replace("in_layers.0", "norm1") + diffusers_key_hf = ( + k.replace("in_layers.0", "norm1") .replace("in_layers.2", "conv1") .replace("out_layers.0", "norm2") .replace("out_layers.3", "conv2") @@ -395,11 +395,9 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): converted_state_dict[diffusers_key_hf] = state_dict.get(k) else: for k in middle_blocks[key]: - diffusers_key_hf = k.replace( - f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}" - ) + diffusers_key_hf = k.replace(f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}") converted_state_dict[diffusers_key_hf] = state_dict.get(k) - + # mid block converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight") converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias") @@ -408,7 +406,9 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): cond_embedding_blocks = { ".".join(layer.split(".")[:2]) for layer in state_dict - if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) + if "input_hint_block" in layer + and ("input_hint_block.0" not in layer) + and ("input_hint_block.14" not in layer) } num_cond_embedding_blocks = len(cond_embedding_blocks) @@ -422,13 +422,13 @@ def convert_sai_sd_control_lora_state_dict_to_peft(state_dict): converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get( f"input_hint_block.{cond_block_id}.bias" ) - + for key in [key for key in state_dict if "input_hint_block.0" in key]: diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in") converted_state_dict[diffusers_key] = state_dict.get(key) - + for key in [key for key in state_dict if "input_hint_block.14" in key]: - diffusers_key = key.replace(f"input_hint_block.14", "controlnet_cond_embedding.conv_out") + diffusers_key = key.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out") converted_state_dict[diffusers_key] = state_dict.get(key) return converted_state_dict