mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update
This commit is contained in:
@@ -27,6 +27,7 @@ from ..utils import (
|
||||
logging,
|
||||
)
|
||||
from .single_file_utils import (
|
||||
create_controlnet_model,
|
||||
create_scheduler,
|
||||
create_text_encoders_and_tokenizers,
|
||||
create_unet_model,
|
||||
@@ -318,11 +319,11 @@ class FromSingleFileMixin:
|
||||
original_config = fetch_original_config(class_name, checkpoint, original_config_file, config_files)
|
||||
|
||||
if class_name == "AutoencoderKL":
|
||||
component = build_component({}, "vae", original_config, checkpoint, pretrained_model_link_or_path)
|
||||
component = create_vae_model(class_name, original_config, checkpoint, pretrained_model_link_or_path)
|
||||
return component["vae"]
|
||||
|
||||
if class_name == "ControlNetModel":
|
||||
component = build_component({}, "controlnet", original_config, checkpoint, pretrained_model_link_or_path)
|
||||
component = create_controlnet_model(class_name, original_config, checkpoint, **kwargs)
|
||||
return component["controlnet"]
|
||||
|
||||
component_names = extract_pipeline_component_names(cls)
|
||||
|
||||
@@ -139,6 +139,10 @@ DIFFUSERS_TO_LDM_MAPPING = {
|
||||
"token_embedding.weight": "transformer.text_model.embeddings.token_embedding.weight",
|
||||
"positional_embedding": "transformer.text_model.embeddings.position_embedding.weight",
|
||||
},
|
||||
"controlnet" : {
|
||||
"controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
|
||||
"controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias"
|
||||
}
|
||||
}
|
||||
|
||||
LDM_VAE_KEY = "first_stage_model."
|
||||
@@ -510,14 +514,16 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
new_checkpoint = {}
|
||||
ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
|
||||
for diffusers_key, ldm_key in ldm_unet_keys.items():
|
||||
if ldm_key not in unet_state_dict:
|
||||
continue
|
||||
new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
|
||||
|
||||
if config["class_embed_type"] in ["timestep", "projection"]:
|
||||
if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
|
||||
class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"]
|
||||
for diffusers_key, ldm_key in class_embed_keys.items():
|
||||
new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
|
||||
|
||||
if config["addition_embed_type"] == "text_time":
|
||||
if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"):
|
||||
addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"]
|
||||
for diffusers_key, ldm_key in addition_embed_keys.items():
|
||||
new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
|
||||
@@ -641,16 +647,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
||||
|
||||
def convert_controlnet_checkpoint(
|
||||
checkpoint,
|
||||
original_config,
|
||||
checkpoint_path,
|
||||
image_size,
|
||||
upcast_attention,
|
||||
extract_ema,
|
||||
use_linear_projection=None,
|
||||
cross_attention_dim=None,
|
||||
config,
|
||||
):
|
||||
|
||||
""""
|
||||
"""
|
||||
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
ctrlnet_config["upcast_attention"] = upcast_attention
|
||||
|
||||
@@ -674,48 +674,108 @@ def convert_controlnet_checkpoint(
|
||||
else:
|
||||
skip_extract_state_dict = False
|
||||
|
||||
new_checkpoint = convert_ldm_unet_checkpoint(checkpoint, original_config)
|
||||
new_checkpoint = {}
|
||||
ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]
|
||||
for diffusers_key, ldm_key in ldm_controlnet_keys.items():
|
||||
if ldm_key not in checkpoint:
|
||||
continue
|
||||
new_checkpoint[diffusers_key] = checkpoint[ldm_key]
|
||||
|
||||
# Retrieves the keys for the input blocks only
|
||||
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
|
||||
input_blocks = {
|
||||
layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
|
||||
for layer_id in range(num_input_blocks)
|
||||
}
|
||||
|
||||
# Down blocks
|
||||
for i in range(1, num_input_blocks):
|
||||
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
||||
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
||||
|
||||
resnets = [
|
||||
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
||||
]
|
||||
update_unet_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
{"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
|
||||
)
|
||||
|
||||
if f"input_blocks.{i}.0.op.weight" in checkpoint:
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint.pop(
|
||||
f"input_blocks.{i}.0.op.weight"
|
||||
)
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint.pop(
|
||||
f"input_blocks.{i}.0.op.bias"
|
||||
)
|
||||
|
||||
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||
if attentions:
|
||||
update_unet_attention_ldm_to_diffusers(
|
||||
attentions,
|
||||
new_checkpoint,
|
||||
checkpoint,
|
||||
{"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
|
||||
)
|
||||
|
||||
orig_index = 0
|
||||
|
||||
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
|
||||
f"input_hint_block.{orig_index}.weight"
|
||||
)
|
||||
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
|
||||
f"input_hint_block.{orig_index}.bias"
|
||||
)
|
||||
|
||||
orig_index += 2
|
||||
diffusers_index = 0
|
||||
|
||||
while diffusers_index < 6:
|
||||
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
|
||||
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = checkpoint.pop(
|
||||
f"input_hint_block.{orig_index}.weight"
|
||||
)
|
||||
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
|
||||
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = checkpoint.pop(
|
||||
f"input_hint_block.{orig_index}.bias"
|
||||
)
|
||||
diffusers_index += 1
|
||||
orig_index += 2
|
||||
|
||||
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
|
||||
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = checkpoint.pop(
|
||||
f"input_hint_block.{orig_index}.weight"
|
||||
)
|
||||
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
|
||||
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = checkpoint.pop(
|
||||
f"input_hint_block.{orig_index}.bias"
|
||||
)
|
||||
|
||||
# down blocks
|
||||
for i in range(num_input_blocks):
|
||||
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
|
||||
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
|
||||
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = checkpoint.pop(f"zero_convs.{i}.0.weight")
|
||||
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = checkpoint.pop(f"zero_convs.{i}.0.bias")
|
||||
|
||||
# mid block
|
||||
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
|
||||
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
|
||||
new_checkpoint["controlnet_mid_block.weight"] = checkpoint.pop("middle_block_out.0.weight")
|
||||
new_checkpoint["controlnet_mid_block.bias"] = checkpoint.pop("middle_block_out.0.bias")
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def create_controlnet_model(pipeline_class_name, original_config, checkpoint, **kwargs):
    """Build a diffusers ``ControlNetModel`` from an original-format checkpoint.

    Args:
        pipeline_class_name: Name of the class being loaded; forwarded to
            ``determine_image_size``.
        original_config: Original (non-diffusers) model configuration, used to
            derive the diffusers ControlNet config.
        checkpoint: Original-format state dict.
            NOTE(review): ``convert_controlnet_checkpoint`` appears to ``pop``
            entries from this mapping, so the caller's dict may be mutated --
            confirm before reusing it afterwards.
        **kwargs: Forwarded unchanged to ``determine_image_size``.

    Returns:
        dict: A single-entry mapping ``{"controlnet": model}`` holding the
        loaded ``ControlNetModel``.
    """
    from ..models import ControlNetModel

    image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
    config = create_controlnet_diffusers_config(original_config, image_size=image_size)

    # The converter indexes its second argument with diffusers-style keys
    # (e.g. ``config["layers_per_block"]``), so it must receive the diffusers
    # config built above rather than the original config.
    diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, config)

    # Instantiate under ``init_empty_weights`` when accelerate is installed;
    # otherwise fall back to a no-op context manager.
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        controlnet = ControlNetModel(**config)

    if is_accelerate_available():
        # Assign every converted tensor onto the model individually, on CPU.
        for param_name, param in diffusers_format_controlnet_checkpoint.items():
            set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
    else:
        controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)

    return {"controlnet": controlnet}
|
||||
|
||||
|
||||
|
||||
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
|
||||
for ldm_key in keys:
|
||||
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
|
||||
@@ -999,6 +1059,8 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoi
|
||||
|
||||
|
||||
def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
|
||||
from ..models import AutoencoderKL
|
||||
|
||||
image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
|
||||
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalControlnetMixin
|
||||
from ..loaders import FromSingleFileMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -102,7 +102,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
return embedding
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromSingleFileMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user