Dhruv Nair
2024-01-17 16:02:08 +00:00
parent 249f78e1c8
commit 8a24733654
3 changed files with 94 additions and 31 deletions

src/diffusers/loaders/single_file.py

@@ -27,6 +27,7 @@ from ..utils import (
     logging,
 )
 from .single_file_utils import (
+    create_controlnet_model,
     create_scheduler,
     create_text_encoders_and_tokenizers,
     create_unet_model,
@@ -318,11 +319,11 @@ class FromSingleFileMixin:
         original_config = fetch_original_config(class_name, checkpoint, original_config_file, config_files)

         if class_name == "AutoencoderKL":
-            component = build_component({}, "vae", original_config, checkpoint, pretrained_model_link_or_path)
+            component = create_vae_model(class_name, original_config, checkpoint, pretrained_model_link_or_path)
             return component["vae"]

         if class_name == "ControlNetModel":
-            component = build_component({}, "controlnet", original_config, checkpoint, pretrained_model_link_or_path)
+            component = create_controlnet_model(class_name, original_config, checkpoint, **kwargs)
             return component["controlnet"]

         component_names = extract_pipeline_component_names(cls)
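
Note: the ControlNetModel branch above is what the rest of this commit wires up. A minimal usage sketch of what it enables (the local filename is illustrative, not part of this commit; any original-format LDM ControlNet checkpoint should behave the same):

# Usage sketch for the new dispatch branch; the checkpoint filename is an
# illustrative stand-in for an original-format ControlNet checkpoint.
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_single_file("control_v11p_sd15_canny.safetensors")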

src/diffusers/loaders/single_file_utils.py

@@ -139,6 +139,10 @@ DIFFUSERS_TO_LDM_MAPPING = {
         "token_embedding.weight": "transformer.text_model.embeddings.token_embedding.weight",
         "positional_embedding": "transformer.text_model.embeddings.position_embedding.weight",
     },
+    "controlnet": {
+        "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
+        "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
+    },
 }

 LDM_VAE_KEY = "first_stage_model."
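
Note: the new "controlnet" entry follows the same convention as the other DIFFUSERS_TO_LDM_MAPPING tables — diffusers parameter names on the left, original LDM names on the right. A self-contained toy sketch of how such a table is consumed (mirroring the loop added to convert_controlnet_checkpoint further down; tensors replaced by strings for brevity):

# Toy illustration of applying a diffusers -> LDM key table to a state dict.
mapping = {
    "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
    "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
}
ldm_state_dict = {"input_hint_block.0.weight": "W", "input_hint_block.0.bias": "b"}

new_checkpoint = {}
for diffusers_key, ldm_key in mapping.items():
    if ldm_key not in ldm_state_dict:  # tolerate checkpoints missing optional keys
        continue
    new_checkpoint[diffusers_key] = ldm_state_dict[ldm_key]

assert new_checkpoint["controlnet_cond_embedding.conv_in.weight"] == "W"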
@@ -510,14 +514,16 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
     new_checkpoint = {}
     ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
     for diffusers_key, ldm_key in ldm_unet_keys.items():
         if ldm_key not in unet_state_dict:
             continue
         new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

-    if config["class_embed_type"] in ["timestep", "projection"]:
+    if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
         class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"]
         for diffusers_key, ldm_key in class_embed_keys.items():
             new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]

-    if config["addition_embed_type"] == "text_time":
+    if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"):
         addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"]
         for diffusers_key, ldm_key in addition_embed_keys.items():
             new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
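
Note: the two `in` guards matter because a config dict that does not define these optional keys would raise KeyError under plain subscripting; the guard makes them truly optional. A minimal illustration:

# Minimal illustration of the guard on a config lacking the optional keys.
config = {"layers_per_block": 2}  # no "class_embed_type" / "addition_embed_type"

if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
    print("convert class embedding keys")
else:
    print("skipped: key absent or not a handled type")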
@@ -641,16 +647,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
 def convert_controlnet_checkpoint(
     checkpoint,
     original_config,
-    checkpoint_path,
-    image_size,
-    upcast_attention,
-    extract_ema,
-    use_linear_projection=None,
-    cross_attention_dim=None,
+    config,
 ):
-    """"
-    """
-    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size)
-    ctrlnet_config["upcast_attention"] = upcast_attention
@@ -674,48 +674,108 @@ def convert_controlnet_checkpoint(
     else:
         skip_extract_state_dict = False

-    new_checkpoint = convert_ldm_unet_checkpoint(checkpoint, original_config)
+    new_checkpoint = {}
+
+    ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]
+    for diffusers_key, ldm_key in ldm_controlnet_keys.items():
+        if ldm_key not in checkpoint:
+            continue
+        new_checkpoint[diffusers_key] = checkpoint[ldm_key]
+
+    # Retrieves the keys for the input blocks only
+    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
+    input_blocks = {
+        layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
+        for layer_id in range(num_input_blocks)
+    }
+
+    # Down blocks
+    for i in range(1, num_input_blocks):
+        block_id = (i - 1) // (config["layers_per_block"] + 1)
+        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+        resnets = [
+            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+        ]
+        update_unet_resnet_ldm_to_diffusers(
+            resnets,
+            new_checkpoint,
+            checkpoint,
+            {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
+        )
+
+        if f"input_blocks.{i}.0.op.weight" in checkpoint:
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint.pop(
+                f"input_blocks.{i}.0.op.weight"
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint.pop(
+                f"input_blocks.{i}.0.op.bias"
+            )
+
+        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+        if attentions:
+            update_unet_attention_ldm_to_diffusers(
+                attentions,
+                new_checkpoint,
+                checkpoint,
+                {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
+            )

     orig_index = 0
-    new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
-        f"input_hint_block.{orig_index}.weight"
-    )
-    new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
-        f"input_hint_block.{orig_index}.bias"
-    )
     orig_index += 2

     diffusers_index = 0
     while diffusers_index < 6:
-        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
+        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = checkpoint.pop(
             f"input_hint_block.{orig_index}.weight"
         )
-        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
+        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = checkpoint.pop(
             f"input_hint_block.{orig_index}.bias"
         )
         diffusers_index += 1
         orig_index += 2

-    new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
+    new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = checkpoint.pop(
         f"input_hint_block.{orig_index}.weight"
     )
-    new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
+    new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = checkpoint.pop(
         f"input_hint_block.{orig_index}.bias"
     )

     # down blocks
     for i in range(num_input_blocks):
-        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
-        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
+        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = checkpoint.pop(f"zero_convs.{i}.0.weight")
+        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = checkpoint.pop(f"zero_convs.{i}.0.bias")

     # mid block
-    new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
-    new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
+    new_checkpoint["controlnet_mid_block.weight"] = checkpoint.pop("middle_block_out.0.weight")
+    new_checkpoint["controlnet_mid_block.bias"] = checkpoint.pop("middle_block_out.0.bias")

     return new_checkpoint


+def create_controlnet_model(
+    pipeline_class_name, original_config, checkpoint, **kwargs
+):
+    from ..models import ControlNetModel
+
+    image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
+
+    config = create_controlnet_diffusers_config(original_config, image_size=image_size)
+
+    diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, original_config, config)
+
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        controlnet = ControlNetModel(**config)
+
+    if is_accelerate_available():
+        for param_name, param in diffusers_format_controlnet_checkpoint.items():
+            set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
+    else:
+        controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
+
+    return {"controlnet": controlnet}
+
+
 def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
     for ldm_key in keys:
         diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
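
Note: the init_empty_weights / set_module_tensor_to_device pair in the new create_controlnet_model is the standard accelerate pattern for instantiating a model without allocating real weight memory, then materializing parameters one by one. A self-contained sketch of the same pattern on a toy module (requires accelerate to be installed):

# Sketch of the meta-device loading pattern used by create_controlnet_model,
# demonstrated on a toy module.
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():
    layer = torch.nn.Linear(4, 4)  # parameters live on the "meta" device; no memory allocated

state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
for name, tensor in state_dict.items():
    set_module_tensor_to_device(layer, name, "cpu", value=tensor)  # materialize each parameter on CPU

print(layer(torch.ones(4)))  # the module is now fully usable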
@@ -999,6 +1059,8 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
 def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
     from ..models import AutoencoderKL

     image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
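
Note on the down-block remapping added to convert_controlnet_checkpoint above: original checkpoints store a flat input_blocks list, while diffusers groups layers into down_blocks. With layers_per_block = 2 (the SD 1.x value), each down block spans three consecutive input blocks (two resnets plus a downsampler), which is exactly what the (i - 1) // (layers_per_block + 1) arithmetic recovers:

# Worked example of the flat-index -> (block, slot) arithmetic, assuming
# layers_per_block = 2 as in SD 1.x configs.
layers_per_block = 2
for i in range(1, 7):
    block_id = (i - 1) // (layers_per_block + 1)
    layer_in_block_id = (i - 1) % (layers_per_block + 1)
    print(f"input_blocks.{i} -> down_blocks.{block_id}, slot {layer_in_block_id}")
# input_blocks.1 -> down_blocks.0, slot 0
# input_blocks.2 -> down_blocks.0, slot 1
# input_blocks.3 -> down_blocks.0, slot 2   (usually the downsampler)
# input_blocks.4 -> down_blocks.1, slot 0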

src/diffusers/models/controlnet.py

@@ -19,7 +19,7 @@ from torch import nn
 from torch.nn import functional as F

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalControlnetMixin
+from ..loaders import FromSingleFileMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -102,7 +102,7 @@ class ControlNetConditioningEmbedding(nn.Module):
         return embedding


-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
+class ControlNetModel(ModelMixin, ConfigMixin, FromSingleFileMixin):
     """
     A ControlNet model.
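
Note: with ControlNetModel now deriving from FromSingleFileMixin, the end-to-end flow looks roughly like the sketch below. The repo id and checkpoint filename are examples, not pinned by this commit:

# End-to-end sketch: load an original-format ControlNet checkpoint and plug
# it into a ControlNet pipeline.
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_single_file("control_v11p_sd15_canny.safetensors")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)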