From 8a24733654ebdcfc107c307a58fbdb7610aad653 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Wed, 17 Jan 2024 16:02:08 +0000
Subject: [PATCH] update

---
 src/diffusers/loaders/single_file.py       |   5 +-
 src/diffusers/loaders/single_file_utils.py | 116 ++++++++++++++++-----
 src/diffusers/models/controlnet.py         |   4 +-
 3 files changed, 94 insertions(+), 31 deletions(-)

diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index 2e89de85b5..78659de332 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -27,6 +27,7 @@ from ..utils import (
     logging,
 )
 from .single_file_utils import (
+    create_controlnet_model,
     create_scheduler,
     create_text_encoders_and_tokenizers,
     create_unet_model,
@@ -318,11 +319,11 @@ class FromSingleFileMixin:
         original_config = fetch_original_config(class_name, checkpoint, original_config_file, config_files)
 
         if class_name == "AutoencoderKL":
-            component = build_component({}, "vae", original_config, checkpoint, pretrained_model_link_or_path)
+            component = create_vae_model(class_name, original_config, checkpoint, pretrained_model_link_or_path)
             return component["vae"]
 
         if class_name == "ControlNetModel":
-            component = build_component({}, "controlnet", original_config, checkpoint, pretrained_model_link_or_path)
+            component = create_controlnet_model(class_name, original_config, checkpoint, **kwargs)
             return component["controlnet"]
 
         component_names = extract_pipeline_component_names(cls)
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 71fa08c189..fae6159515 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -139,6 +139,10 @@ DIFFUSERS_TO_LDM_MAPPING = {
         "token_embedding.weight": "transformer.text_model.embeddings.token_embedding.weight",
         "positional_embedding": "transformer.text_model.embeddings.position_embedding.weight",
     },
+    "controlnet": {
+        "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
+        "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
+    },
 }
 
 LDM_VAE_KEY = "first_stage_model."
@@ -510,14 +514,16 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     new_checkpoint = {}
     ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
     for diffusers_key, ldm_key in ldm_unet_keys.items():
+        if ldm_key not in unet_state_dict:
+            continue
         new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
 
-    if config["class_embed_type"] in ["timestep", "projection"]:
+    if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
         class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"]
         for diffusers_key, ldm_key in class_embed_keys.items():
             new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
 
-    if config["addition_embed_type"] == "text_time":
+    if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"):
         addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"]
         for diffusers_key, ldm_key in addition_embed_keys.items():
             new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
@@ -641,16 +647,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
 
 def convert_controlnet_checkpoint(
     checkpoint,
-    original_config,
-    checkpoint_path,
-    image_size,
-    upcast_attention,
-    extract_ema,
-    use_linear_projection=None,
-    cross_attention_dim=None,
+    config,
 ):
-    """"
+    """
     ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     ctrlnet_config["upcast_attention"] = upcast_attention
@@ -674,48 +674,108 @@ def convert_controlnet_checkpoint(
     else:
         skip_extract_state_dict = False
 
-    new_checkpoint = convert_ldm_unet_checkpoint(checkpoint, original_config)
+    new_checkpoint = {}
+    ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]
+    for diffusers_key, ldm_key in ldm_controlnet_keys.items():
+        if ldm_key not in checkpoint:
+            continue
+        new_checkpoint[diffusers_key] = checkpoint[ldm_key]
+
+    # Retrieves the keys for the input blocks only
+    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
+    input_blocks = {
+        layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
+        for layer_id in range(num_input_blocks)
+    }
+
+    # Down blocks
+    for i in range(1, num_input_blocks):
+        block_id = (i - 1) // (config["layers_per_block"] + 1)
+        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+        resnets = [
+            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+        ]
+        update_unet_resnet_ldm_to_diffusers(
+            resnets,
+            new_checkpoint,
+            checkpoint,
+            {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
+        )
+
+        if f"input_blocks.{i}.0.op.weight" in checkpoint:
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint.pop(
+                f"input_blocks.{i}.0.op.weight"
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint.pop(
+                f"input_blocks.{i}.0.op.bias"
+            )
+
+        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+        if attentions:
+            update_unet_attention_ldm_to_diffusers(
+                attentions,
+                new_checkpoint,
+                checkpoint,
+                {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
+            )
 
     orig_index = 0
-
-    new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
-        f"input_hint_block.{orig_index}.weight"
-    )
-    new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
-        f"input_hint_block.{orig_index}.bias"
-    )
-    orig_index += 2
 
     diffusers_index = 0
     while diffusers_index < 6:
-        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
+        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = checkpoint.pop(
             f"input_hint_block.{orig_index}.weight"
         )
-        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
+        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = checkpoint.pop(
             f"input_hint_block.{orig_index}.bias"
         )
         diffusers_index += 1
         orig_index += 2
 
-    new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
+    new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = checkpoint.pop(
         f"input_hint_block.{orig_index}.weight"
     )
-    new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
+    new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = checkpoint.pop(
         f"input_hint_block.{orig_index}.bias"
     )
 
     # down blocks
     for i in range(num_input_blocks):
-        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
-        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
+        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = checkpoint.pop(f"zero_convs.{i}.0.weight")
+        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = checkpoint.pop(f"zero_convs.{i}.0.bias")
 
     # mid block
-    new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
-    new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
+    new_checkpoint["controlnet_mid_block.weight"] = checkpoint.pop("middle_block_out.0.weight")
+    new_checkpoint["controlnet_mid_block.bias"] = checkpoint.pop("middle_block_out.0.bias")
 
     return new_checkpoint
 
 
+def create_controlnet_model(pipeline_class_name, original_config, checkpoint, **kwargs):
+    from ..models import ControlNetModel
+
+    image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
+    config = create_controlnet_diffusers_config(original_config, image_size=image_size)
+    diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, config)
+
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        controlnet = ControlNetModel(**config)
+
+    if is_accelerate_available():
+        for param_name, param in diffusers_format_controlnet_checkpoint.items():
+            set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
+    else:
+        controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
+
+    return {"controlnet": controlnet}
+
+
 def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
     for ldm_key in keys:
         diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
@@ -999,6 +1059,8 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoi
 
 
 def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
+    from ..models import AutoencoderKL
+
     image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
 
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index 1102f4f9d3..8af13a6ec7 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -19,7 +19,7 @@
 from torch import nn
 from torch.nn import functional as F
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalControlnetMixin
+from ..loaders import FromSingleFileMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -102,7 +102,7 @@ class ControlNetConditioningEmbedding(nn.Module):
         return embedding
 
 
-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
+class ControlNetModel(ModelMixin, ConfigMixin, FromSingleFileMixin):
     """
     A ControlNet model.
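
---
A minimal usage sketch of the path this patch wires up: with ControlNetModel
now inheriting FromSingleFileMixin, loading an original-format checkpoint
routes through create_controlnet_model() and convert_controlnet_checkpoint()
above. The checkpoint URL below is only an illustrative example, not part of
the patch; any single-file ControlNet checkpoint should behave the same way.

    # Load an original (LDM-format) ControlNet checkpoint directly.
    # Example URL only; substitute any .pth/.safetensors ControlNet file.
    from diffusers import ControlNetModel

    controlnet = ControlNetModel.from_single_file(
        "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
    )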