mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -26,6 +26,7 @@ from ..utils import (
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.hub_utils import _get_model_file
|
||||
from .single_file_utils import (
|
||||
create_controlnet_model,
|
||||
create_scheduler,
|
||||
|
||||
@@ -29,7 +29,7 @@ from transformers import (
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
from ..models import AutoencoderKL, UNet2DConditionModel
|
||||
from ..models import UNet2DConditionModel
|
||||
from ..schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
@@ -105,6 +105,26 @@ DIFFUSERS_TO_LDM_MAPPING = {
|
||||
"add_embedding.linear_2.bias": "label_emb.0.2.bias",
|
||||
},
|
||||
},
|
||||
"controlnet": {
|
||||
"layers": {
|
||||
"controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
|
||||
"controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
|
||||
"controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight",
|
||||
"controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias",
|
||||
},
|
||||
"class_embed_type": {
|
||||
"class_embedding.linear_1.weight": "label_emb.0.0.weight",
|
||||
"class_embedding.linear_1.bias": "label_emb.0.0.bias",
|
||||
"class_embedding.linear_2.weight": "label_emb.0.2.weight",
|
||||
"class_embedding.linear_2.bias": "label_emb.0.2.bias",
|
||||
},
|
||||
"addition_embed_type": {
|
||||
"add_embedding.linear_1.weight": "label_emb.0.0.weight",
|
||||
"add_embedding.linear_1.bias": "label_emb.0.0.bias",
|
||||
"add_embedding.linear_2.weight": "label_emb.0.2.weight",
|
||||
"add_embedding.linear_2.bias": "label_emb.0.2.bias",
|
||||
},
|
||||
},
|
||||
"vae": {
|
||||
"encoder.conv_in.weight": "encoder.conv_in.weight",
|
||||
"encoder.conv_in.bias": "encoder.conv_in.bias",
|
||||
@@ -139,18 +159,30 @@ DIFFUSERS_TO_LDM_MAPPING = {
|
||||
"token_embedding.weight": "transformer.text_model.embeddings.token_embedding.weight",
|
||||
"positional_embedding": "transformer.text_model.embeddings.position_embedding.weight",
|
||||
},
|
||||
"controlnet" : {
|
||||
"controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
|
||||
"controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
LDM_VAE_KEY = "first_stage_model."
|
||||
LDM_UNET_KEY = "model.diffusion_model."
|
||||
LDM_CONTROLNET_KEY = "control_model."
|
||||
LDM_CLIP_CONFIG_NAME = "openai/clip-vit-large-patch14"
|
||||
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
|
||||
|
||||
SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = ['cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias', 'cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight', 'cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias', 'cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight', 'cond_stage_model.model.transformer.resblocks.23.ln_1.bias', 'cond_stage_model.model.transformer.resblocks.23.ln_1.weight', 'cond_stage_model.model.transformer.resblocks.23.ln_2.bias', 'cond_stage_model.model.transformer.resblocks.23.ln_2.weight', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight', 'cond_stage_model.model.text_projection']
|
||||
SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
|
||||
"cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
|
||||
"cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
|
||||
"cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.23.ln_1.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.23.ln_1.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.23.ln_2.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.23.ln_2.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight",
|
||||
"cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias",
|
||||
"cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight",
|
||||
"cond_stage_model.model.text_projection",
|
||||
]
|
||||
|
||||
textenc_conversion_lst = [
|
||||
("positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
||||
@@ -424,11 +456,26 @@ def create_unet_diffusers_config(original_config, image_size: int):
|
||||
|
||||
def create_controlnet_diffusers_config(original_config, image_size: int):
|
||||
unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
|
||||
config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
diffusers_unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
|
||||
config["conditioning_channels"] = unet_params["hint_channels"]
|
||||
controlnet_config = {
|
||||
"conditioning_channels": unet_params["hint_channels"],
|
||||
"in_channels": diffusers_unet_config["in_channels"],
|
||||
"down_block_types": diffusers_unet_config["down_block_types"],
|
||||
"block_out_channels": diffusers_unet_config["block_out_channels"],
|
||||
"layers_per_block": diffusers_unet_config["layers_per_block"],
|
||||
"cross_attention_dim": diffusers_unet_config["cross_attention_dim"],
|
||||
"attention_head_dim": diffusers_unet_config["attention_head_dim"],
|
||||
"use_linear_projection": diffusers_unet_config["use_linear_projection"],
|
||||
"class_embed_type": diffusers_unet_config["class_embed_type"],
|
||||
"addition_embed_type": diffusers_unet_config["addition_embed_type"],
|
||||
"addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"],
|
||||
"projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"],
|
||||
"transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"],
|
||||
}
|
||||
|
||||
return controlnet_config
|
||||
|
||||
return config
|
||||
|
||||
def create_vae_diffusers_config(original_config, image_size: int):
|
||||
"""
|
||||
@@ -475,7 +522,9 @@ def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint,
|
||||
new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
|
||||
|
||||
|
||||
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, skip_extract_state_dict=False):
|
||||
def convert_ldm_unet_checkpoint(
|
||||
checkpoint, config, unet_key, path=None, extract_ema=False, skip_extract_state_dict=False
|
||||
):
|
||||
"""
|
||||
Takes a state dict and a config, and returns a converted checkpoint.
|
||||
"""
|
||||
@@ -649,42 +698,32 @@ def convert_controlnet_checkpoint(
|
||||
checkpoint,
|
||||
config,
|
||||
):
|
||||
|
||||
"""
|
||||
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
ctrlnet_config["upcast_attention"] = upcast_attention
|
||||
|
||||
ctrlnet_config.pop("sample_size")
|
||||
|
||||
if use_linear_projection is not None:
|
||||
ctrlnet_config["use_linear_projection"] = use_linear_projection
|
||||
|
||||
if cross_attention_dim is not None:
|
||||
ctrlnet_config["cross_attention_dim"] = cross_attention_dim
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
controlnet = ControlNetModel(**ctrlnet_config)
|
||||
"""
|
||||
|
||||
# Some controlnet ckpt files are distributed independently from the rest of the
|
||||
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
|
||||
if "time_embed.0.weight" in checkpoint:
|
||||
skip_extract_state_dict = True
|
||||
controlnet_state_dict = checkpoint
|
||||
|
||||
else:
|
||||
skip_extract_state_dict = False
|
||||
controlnet_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
controlnet_key = LDM_CONTROLNET_KEY
|
||||
for key in keys:
|
||||
if key.startswith(controlnet_key):
|
||||
controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.pop(key)
|
||||
|
||||
new_checkpoint = {}
|
||||
ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]
|
||||
for diffusers_key, ldm_key in ldm_controlnet_keys.items():
|
||||
if ldm_key not in checkpoint:
|
||||
if ldm_key not in controlnet_state_dict:
|
||||
continue
|
||||
new_checkpoint[diffusers_key] = checkpoint[ldm_key]
|
||||
new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key]
|
||||
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
|
||||
num_input_blocks = len(
|
||||
{".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer}
|
||||
)
|
||||
input_blocks = {
|
||||
layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
|
||||
layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_input_blocks)
|
||||
}
|
||||
|
||||
@@ -699,15 +738,15 @@ def convert_controlnet_checkpoint(
|
||||
update_unet_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
controlnet_state_dict,
|
||||
{"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
|
||||
)
|
||||
|
||||
if f"input_blocks.{i}.0.op.weight" in checkpoint:
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint.pop(
|
||||
if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.pop(
|
||||
f"input_blocks.{i}.0.op.weight"
|
||||
)
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint.pop(
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.pop(
|
||||
f"input_blocks.{i}.0.op.bias"
|
||||
)
|
||||
|
||||
@@ -716,55 +755,55 @@ def convert_controlnet_checkpoint(
|
||||
update_unet_attention_ldm_to_diffusers(
|
||||
attentions,
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
controlnet_state_dict,
|
||||
{"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
|
||||
)
|
||||
|
||||
orig_index = 0
|
||||
orig_index += 2
|
||||
diffusers_index = 0
|
||||
|
||||
while diffusers_index < 6:
|
||||
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = checkpoint.pop(
|
||||
f"input_hint_block.{orig_index}.weight"
|
||||
)
|
||||
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = checkpoint.pop(
|
||||
f"input_hint_block.{orig_index}.bias"
|
||||
)
|
||||
diffusers_index += 1
|
||||
orig_index += 2
|
||||
|
||||
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = checkpoint.pop(
|
||||
f"input_hint_block.{orig_index}.weight"
|
||||
)
|
||||
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = checkpoint.pop(
|
||||
f"input_hint_block.{orig_index}.bias"
|
||||
)
|
||||
|
||||
# down blocks
|
||||
# controlnet down blocks
|
||||
for i in range(num_input_blocks):
|
||||
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = checkpoint.pop(f"zero_convs.{i}.0.weight")
|
||||
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = checkpoint.pop(f"zero_convs.{i}.0.bias")
|
||||
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.weight")
|
||||
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.bias")
|
||||
|
||||
# mid block
|
||||
new_checkpoint["controlnet_mid_block.weight"] = checkpoint.pop("middle_block_out.0.weight")
|
||||
new_checkpoint["controlnet_mid_block.bias"] = checkpoint.pop("middle_block_out.0.bias")
|
||||
new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.pop("middle_block_out.0.weight")
|
||||
new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.pop("middle_block_out.0.bias")
|
||||
|
||||
# controlnet cond embedding blocks
|
||||
cond_embedding_blocks = {
|
||||
".".join(layer.split(".")[:2])
|
||||
for layer in controlnet_state_dict
|
||||
if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
|
||||
}
|
||||
num_cond_embedding_blocks = len(cond_embedding_blocks)
|
||||
|
||||
for idx in range(1, num_cond_embedding_blocks):
|
||||
diffusers_idx = idx - 1
|
||||
cond_block_id = 2 * idx
|
||||
|
||||
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.pop(
|
||||
f"input_hint_block.{cond_block_id}.weight"
|
||||
)
|
||||
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.pop(
|
||||
f"input_hint_block.{cond_block_id}.bias"
|
||||
)
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def create_controlnet_model(
|
||||
pipeline_class_name, original_config, checkpoint, **kwargs
|
||||
):
|
||||
def create_controlnet_model(pipeline_class_name, original_config, checkpoint, **kwargs):
|
||||
from ..models import ControlNetModel
|
||||
|
||||
image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
|
||||
config = create_controlnet_diffusers_config(original_config, image_size=image_size)
|
||||
diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, original_config)
|
||||
upcast_attention = kwargs.get("upcast_attention", False)
|
||||
|
||||
diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size)
|
||||
diffusers_config["upcast_attention"] = upcast_attention
|
||||
|
||||
diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, diffusers_config)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
controlnet = ControlNetModel(**config)
|
||||
controlnet = ControlNetModel(**diffusers_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
for param_name, param in diffusers_format_controlnet_checkpoint.items():
|
||||
@@ -775,7 +814,6 @@ def create_controlnet_model(
|
||||
return {"controlnet": controlnet}
|
||||
|
||||
|
||||
|
||||
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
|
||||
for ldm_key in keys:
|
||||
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
|
||||
|
||||
Reference in New Issue
Block a user