mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2024-01-18 04:12:20 +00:00
parent 8a24733654
commit de77ff6831
2 changed files with 111 additions and 72 deletions

View File

@@ -26,6 +26,7 @@ from ..utils import (
is_transformers_available,
logging,
)
+from ..utils.hub_utils import _get_model_file
from .single_file_utils import (
create_controlnet_model,
create_scheduler,

View File

@@ -29,7 +29,7 @@ from transformers import (
CLIPTokenizer,
)
-from ..models import AutoencoderKL, UNet2DConditionModel
+from ..models import UNet2DConditionModel
from ..schedulers import (
DDIMScheduler,
DDPMScheduler,
@@ -105,6 +105,26 @@ DIFFUSERS_TO_LDM_MAPPING = {
"add_embedding.linear_2.bias": "label_emb.0.2.bias",
},
},
"controlnet": {
"layers": {
"controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
"controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
"controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight",
"controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias",
},
"class_embed_type": {
"class_embedding.linear_1.weight": "label_emb.0.0.weight",
"class_embedding.linear_1.bias": "label_emb.0.0.bias",
"class_embedding.linear_2.weight": "label_emb.0.2.weight",
"class_embedding.linear_2.bias": "label_emb.0.2.bias",
},
"addition_embed_type": {
"add_embedding.linear_1.weight": "label_emb.0.0.weight",
"add_embedding.linear_1.bias": "label_emb.0.0.bias",
"add_embedding.linear_2.weight": "label_emb.0.2.weight",
"add_embedding.linear_2.bias": "label_emb.0.2.bias",
},
},
"vae": {
"encoder.conv_in.weight": "encoder.conv_in.weight",
"encoder.conv_in.bias": "encoder.conv_in.bias",
@@ -139,18 +159,30 @@ DIFFUSERS_TO_LDM_MAPPING = {
"token_embedding.weight": "transformer.text_model.embeddings.token_embedding.weight",
"positional_embedding": "transformer.text_model.embeddings.position_embedding.weight",
},
"controlnet" : {
"controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
"controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias"
}
}
LDM_VAE_KEY = "first_stage_model."
LDM_UNET_KEY = "model.diffusion_model."
LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_CONFIG_NAME = "openai/clip-vit-large-patch14"
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
-SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = ['cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias', 'cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight', 'cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias', 'cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight', 'cond_stage_model.model.transformer.resblocks.23.ln_1.bias', 'cond_stage_model.model.transformer.resblocks.23.ln_1.weight', 'cond_stage_model.model.transformer.resblocks.23.ln_2.bias', 'cond_stage_model.model.transformer.resblocks.23.ln_2.weight', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight', 'cond_stage_model.model.text_projection']
+SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
+"cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
+"cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
+"cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias",
+"cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight",
+"cond_stage_model.model.transformer.resblocks.23.ln_1.bias",
+"cond_stage_model.model.transformer.resblocks.23.ln_1.weight",
+"cond_stage_model.model.transformer.resblocks.23.ln_2.bias",
+"cond_stage_model.model.transformer.resblocks.23.ln_2.weight",
+"cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias",
+"cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight",
+"cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias",
+"cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight",
+"cond_stage_model.model.text_projection",
+]
textenc_conversion_lst = [
("positional_embedding", "text_model.embeddings.position_embedding.weight"),
@@ -424,11 +456,26 @@ def create_unet_diffusers_config(original_config, image_size: int):
def create_controlnet_diffusers_config(original_config, image_size: int):
unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
-config = create_unet_diffusers_config(original_config, image_size=image_size)
-config["conditioning_channels"] = unet_params["hint_channels"]
-return config
+diffusers_unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+controlnet_config = {
+"conditioning_channels": unet_params["hint_channels"],
+"in_channels": diffusers_unet_config["in_channels"],
+"down_block_types": diffusers_unet_config["down_block_types"],
+"block_out_channels": diffusers_unet_config["block_out_channels"],
+"layers_per_block": diffusers_unet_config["layers_per_block"],
+"cross_attention_dim": diffusers_unet_config["cross_attention_dim"],
+"attention_head_dim": diffusers_unet_config["attention_head_dim"],
+"use_linear_projection": diffusers_unet_config["use_linear_projection"],
+"class_embed_type": diffusers_unet_config["class_embed_type"],
+"addition_embed_type": diffusers_unet_config["addition_embed_type"],
+"addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"],
+"projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"],
+"transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"],
+}
+return controlnet_config
def create_vae_diffusers_config(original_config, image_size: int):
"""
@@ -475,7 +522,9 @@ def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint,
new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
-def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, skip_extract_state_dict=False):
+def convert_ldm_unet_checkpoint(
+checkpoint, config, unet_key, path=None, extract_ema=False, skip_extract_state_dict=False
+):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
@@ -649,42 +698,32 @@ def convert_controlnet_checkpoint(
checkpoint,
config,
):
"""
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size)
ctrlnet_config["upcast_attention"] = upcast_attention
ctrlnet_config.pop("sample_size")
if use_linear_projection is not None:
ctrlnet_config["use_linear_projection"] = use_linear_projection
if cross_attention_dim is not None:
ctrlnet_config["cross_attention_dim"] = cross_attention_dim
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
controlnet = ControlNetModel(**ctrlnet_config)
"""
+# Some controlnet ckpt files are distributed independently from the rest of the
+# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
+if "time_embed.0.weight" in checkpoint:
+skip_extract_state_dict = True
+controlnet_state_dict = checkpoint
+else:
+skip_extract_state_dict = False
+controlnet_state_dict = {}
+keys = list(checkpoint.keys())
+controlnet_key = LDM_CONTROLNET_KEY
+for key in keys:
+if key.startswith(controlnet_key):
+controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
-ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]
+ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
for diffusers_key, ldm_key in ldm_controlnet_keys.items():
-if ldm_key not in checkpoint:
+if ldm_key not in controlnet_state_dict:
continue
-new_checkpoint[diffusers_key] = checkpoint[ldm_key]
+new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key]
# Retrieves the keys for the input blocks only
-num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
+num_input_blocks = len(
+{".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer}
+)
input_blocks = {
-layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
+layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
@@ -699,15 +738,15 @@ def convert_controlnet_checkpoint(
update_unet_resnet_ldm_to_diffusers(
resnets,
new_checkpoint,
-checkpoint,
+controlnet_state_dict,
{"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
)
if f"input_blocks.{i}.0.op.weight" in checkpoint:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint.pop(
if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint.pop(
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
@@ -716,55 +755,55 @@ def convert_controlnet_checkpoint(
update_unet_attention_ldm_to_diffusers(
attentions,
new_checkpoint,
-checkpoint,
+controlnet_state_dict,
{"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
)
-orig_index = 0
-orig_index += 2
-diffusers_index = 0
-while diffusers_index < 6:
-new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = checkpoint.pop(
-f"input_hint_block.{orig_index}.weight"
-)
-new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = checkpoint.pop(
-f"input_hint_block.{orig_index}.bias"
-)
-diffusers_index += 1
-orig_index += 2
-new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = checkpoint.pop(
-f"input_hint_block.{orig_index}.weight"
-)
-new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = checkpoint.pop(
-f"input_hint_block.{orig_index}.bias"
-)
-# down blocks
+# controlnet down blocks
for i in range(num_input_blocks):
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = checkpoint.pop(f"zero_convs.{i}.0.weight")
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = checkpoint.pop(f"zero_convs.{i}.0.bias")
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.weight")
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.bias")
# mid block
new_checkpoint["controlnet_mid_block.weight"] = checkpoint.pop("middle_block_out.0.weight")
new_checkpoint["controlnet_mid_block.bias"] = checkpoint.pop("middle_block_out.0.bias")
new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.pop("middle_block_out.0.weight")
new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.pop("middle_block_out.0.bias")
+# controlnet cond embedding blocks
+cond_embedding_blocks = {
+".".join(layer.split(".")[:2])
+for layer in controlnet_state_dict
+if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
+}
+num_cond_embedding_blocks = len(cond_embedding_blocks)
+for idx in range(1, num_cond_embedding_blocks + 1):
+diffusers_idx = idx - 1
+cond_block_id = 2 * idx
+new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.pop(
+f"input_hint_block.{cond_block_id}.weight"
+)
+new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.pop(
+f"input_hint_block.{cond_block_id}.bias"
+)
return new_checkpoint
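(A hedged aside, not part of the diff: the cond-embedding loop above works because the LDM input_hint_block stores its parameterized convs at even indices 0-14, where index 0 maps to conv_in, index 14 to conv_out, and the six convs in between to controlnet_cond_embedding.blocks.0-5. A hypothetical standalone sketch of that reindexing:)

# Hypothetical helper illustrating the index arithmetic of the loop above.
def hint_to_cond_embedding(idx):
    # idx runs from 1 to num_cond_embedding_blocks (6 for standard ControlNets)
    return f"input_hint_block.{2 * idx}", f"controlnet_cond_embedding.blocks.{idx - 1}"

# hint_to_cond_embedding(1) == ("input_hint_block.2", "controlnet_cond_embedding.blocks.0")
# hint_to_cond_embedding(6) == ("input_hint_block.12", "controlnet_cond_embedding.blocks.5")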
-def create_controlnet_model(
-pipeline_class_name, original_config, checkpoint, **kwargs
-):
+def create_controlnet_model(pipeline_class_name, original_config, checkpoint, **kwargs):
from ..models import ControlNetModel
image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
-config = create_controlnet_diffusers_config(original_config, image_size=image_size)
-diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, original_config)
+upcast_attention = kwargs.get("upcast_attention", False)
+diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size)
+diffusers_config["upcast_attention"] = upcast_attention
+diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, diffusers_config)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
-controlnet = ControlNetModel(**config)
+controlnet = ControlNetModel(**diffusers_config)
if is_accelerate_available():
for param_name, param in diffusers_format_controlnet_checkpoint.items():
@@ -775,7 +814,6 @@ def create_controlnet_model(
return {"controlnet": controlnet}
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
for ldm_key in keys:
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")