diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index 78659de332..723fe9462c 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -26,6 +26,7 @@ from ..utils import (
     is_transformers_available,
     logging,
 )
+from ..utils.hub_utils import _get_model_file
 from .single_file_utils import (
     create_controlnet_model,
     create_scheduler,
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index fae6159515..7ad49cc58b 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -29,7 +29,7 @@ from transformers import (
     CLIPTokenizer,
 )
 
-from ..models import AutoencoderKL, UNet2DConditionModel
+from ..models import UNet2DConditionModel
 from ..schedulers import (
     DDIMScheduler,
     DDPMScheduler,
@@ -105,6 +105,26 @@ DIFFUSERS_TO_LDM_MAPPING = {
             "add_embedding.linear_2.bias": "label_emb.0.2.bias",
         },
     },
+    "controlnet": {
+        "layers": {
+            "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
+            "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias",
+            "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight",
+            "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias",
+        },
+        "class_embed_type": {
+            "class_embedding.linear_1.weight": "label_emb.0.0.weight",
+            "class_embedding.linear_1.bias": "label_emb.0.0.bias",
+            "class_embedding.linear_2.weight": "label_emb.0.2.weight",
+            "class_embedding.linear_2.bias": "label_emb.0.2.bias",
+        },
+        "addition_embed_type": {
+            "add_embedding.linear_1.weight": "label_emb.0.0.weight",
+            "add_embedding.linear_1.bias": "label_emb.0.0.bias",
+            "add_embedding.linear_2.weight": "label_emb.0.2.weight",
+            "add_embedding.linear_2.bias": "label_emb.0.2.bias",
+        },
+    },
     "vae": {
         "encoder.conv_in.weight": "encoder.conv_in.weight",
         "encoder.conv_in.bias": "encoder.conv_in.bias",
@@ -139,18 +159,30 @@ DIFFUSERS_TO_LDM_MAPPING = {
         "token_embedding.weight": "transformer.text_model.embeddings.token_embedding.weight",
         "positional_embedding": "transformer.text_model.embeddings.position_embedding.weight",
     },
-    "controlnet" : {
-        "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight",
-        "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias"
-    }
 }
 
+
 LDM_VAE_KEY = "first_stage_model."
 LDM_UNET_KEY = "model.diffusion_model."
+LDM_CONTROLNET_KEY = "control_model."
 LDM_CLIP_CONFIG_NAME = "openai/clip-vit-large-patch14"
 LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
 
-SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = ['cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias', 'cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight', 'cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias', 'cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight', 'cond_stage_model.model.transformer.resblocks.23.ln_1.bias', 'cond_stage_model.model.transformer.resblocks.23.ln_1.weight', 'cond_stage_model.model.transformer.resblocks.23.ln_2.bias', 'cond_stage_model.model.transformer.resblocks.23.ln_2.weight', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias', 'cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight', 'cond_stage_model.model.text_projection']
+SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
+    "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
+    "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
+    "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias",
+    "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight",
+    "cond_stage_model.model.transformer.resblocks.23.ln_1.bias",
+    "cond_stage_model.model.transformer.resblocks.23.ln_1.weight",
+    "cond_stage_model.model.transformer.resblocks.23.ln_2.bias",
+    "cond_stage_model.model.transformer.resblocks.23.ln_2.weight",
+    "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias",
+    "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight",
+    "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias",
+    "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight",
+    "cond_stage_model.model.text_projection",
+]
 
 textenc_conversion_lst = [
     ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
@@ -424,11 +456,26 @@ def create_unet_diffusers_config(original_config, image_size: int):
 def create_controlnet_diffusers_config(original_config, image_size: int):
     unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
-    config = create_unet_diffusers_config(original_config, image_size=image_size)
+    diffusers_unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
 
-    config["conditioning_channels"] = unet_params["hint_channels"]
+    controlnet_config = {
+        "conditioning_channels": unet_params["hint_channels"],
+        "in_channels": diffusers_unet_config["in_channels"],
+        "down_block_types": diffusers_unet_config["down_block_types"],
+        "block_out_channels": diffusers_unet_config["block_out_channels"],
+        "layers_per_block": diffusers_unet_config["layers_per_block"],
+        "cross_attention_dim": diffusers_unet_config["cross_attention_dim"],
+        "attention_head_dim": diffusers_unet_config["attention_head_dim"],
+        "use_linear_projection": diffusers_unet_config["use_linear_projection"],
+        "class_embed_type": diffusers_unet_config["class_embed_type"],
+        "addition_embed_type": diffusers_unet_config["addition_embed_type"],
+        "addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"],
+        "projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"],
+        "transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"],
+    }
+
+    return controlnet_config
 
-    return config
 
 def create_vae_diffusers_config(original_config, image_size: int):
     """
@@ -475,7 +522,9 @@ def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
         new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
 
 
-def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, skip_extract_state_dict=False):
+def convert_ldm_unet_checkpoint(
+    checkpoint, config, unet_key, path=None, extract_ema=False, skip_extract_state_dict=False
+):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
@@ -649,42 +698,32 @@ def convert_controlnet_checkpoint(
     checkpoint,
     config,
 ):
-
-    """
-    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size)
-    ctrlnet_config["upcast_attention"] = upcast_attention
-
-    ctrlnet_config.pop("sample_size")
-
-    if use_linear_projection is not None:
-        ctrlnet_config["use_linear_projection"] = use_linear_projection
-
-    if cross_attention_dim is not None:
-        ctrlnet_config["cross_attention_dim"] = cross_attention_dim
-
-    ctx = init_empty_weights if is_accelerate_available() else nullcontext
-    with ctx():
-        controlnet = ControlNetModel(**ctrlnet_config)
-    """
-
     # Some controlnet ckpt files are distributed independently from the rest of the
     # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
     if "time_embed.0.weight" in checkpoint:
-        skip_extract_state_dict = True
+        controlnet_state_dict = checkpoint
+
     else:
-        skip_extract_state_dict = False
+        controlnet_state_dict = {}
+        keys = list(checkpoint.keys())
+        controlnet_key = LDM_CONTROLNET_KEY
+        for key in keys:
+            if key.startswith(controlnet_key):
+                controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.pop(key)
 
     new_checkpoint = {}
-    ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]
+    ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
     for diffusers_key, ldm_key in ldm_controlnet_keys.items():
-        if ldm_key not in checkpoint:
+        if ldm_key not in controlnet_state_dict:
             continue
-        new_checkpoint[diffusers_key] = checkpoint[ldm_key]
+        new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key]
 
     # Retrieves the keys for the input blocks only
-    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in checkpoint if "input_blocks" in layer})
+    num_input_blocks = len(
+        {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer}
+    )
     input_blocks = {
-        layer_id: [key for key in checkpoint if f"input_blocks.{layer_id}" in key]
+        layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key]
         for layer_id in range(num_input_blocks)
     }
 
@@ -699,15 +738,15 @@ def convert_controlnet_checkpoint(
         update_unet_resnet_ldm_to_diffusers(
             resnets,
             new_checkpoint,
-            checkpoint,
+            controlnet_state_dict,
             {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
         )
 
-        if f"input_blocks.{i}.0.op.weight" in checkpoint:
-            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint.pop(
+        if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.pop(
                 f"input_blocks.{i}.0.op.weight"
             )
-            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint.pop(
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.pop(
                 f"input_blocks.{i}.0.op.bias"
             )
 
@@ -716,55 +755,55 @@ def convert_controlnet_checkpoint(
         update_unet_attention_ldm_to_diffusers(
             attentions,
             new_checkpoint,
-            checkpoint,
+            controlnet_state_dict,
             {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
         )
 
-    orig_index = 0
-    orig_index += 2
-    diffusers_index = 0
-
-    while diffusers_index < 6:
-        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = checkpoint.pop(
-            f"input_hint_block.{orig_index}.weight"
-        )
-        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = checkpoint.pop(
-            f"input_hint_block.{orig_index}.bias"
-        )
-        diffusers_index += 1
-        orig_index += 2
-
-    new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = checkpoint.pop(
-        f"input_hint_block.{orig_index}.weight"
-    )
-    new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = checkpoint.pop(
-        f"input_hint_block.{orig_index}.bias"
-    )
-
-    # down blocks
+    # controlnet down blocks
     for i in range(num_input_blocks):
-        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = checkpoint.pop(f"zero_convs.{i}.0.weight")
-        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = checkpoint.pop(f"zero_convs.{i}.0.bias")
+        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.weight")
+        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.bias")
 
     # mid block
-    new_checkpoint["controlnet_mid_block.weight"] = checkpoint.pop("middle_block_out.0.weight")
-    new_checkpoint["controlnet_mid_block.bias"] = checkpoint.pop("middle_block_out.0.bias")
+    new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.pop("middle_block_out.0.weight")
+    new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.pop("middle_block_out.0.bias")
+
+    # controlnet cond embedding blocks
+    cond_embedding_blocks = {
+        ".".join(layer.split(".")[:2])
+        for layer in controlnet_state_dict
+        if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
+    }
+    num_cond_embedding_blocks = len(cond_embedding_blocks)
+
+    for idx in range(1, num_cond_embedding_blocks + 1):
+        diffusers_idx = idx - 1
+        cond_block_id = 2 * idx
+
+        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.pop(
+            f"input_hint_block.{cond_block_id}.weight"
+        )
+        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.pop(
+            f"input_hint_block.{cond_block_id}.bias"
+        )
 
     return new_checkpoint
 
 
-def create_controlnet_model(
-    pipeline_class_name, original_config, checkpoint, **kwargs
-):
+def create_controlnet_model(pipeline_class_name, original_config, checkpoint, **kwargs):
     from ..models import ControlNetModel
 
     image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
 
-    config = create_controlnet_diffusers_config(original_config, image_size=image_size)
-    diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, original_config)
+    upcast_attention = kwargs.get("upcast_attention", False)
+
+    diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size)
+    diffusers_config["upcast_attention"] = upcast_attention
+
+    diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, diffusers_config)
 
     ctx = init_empty_weights if is_accelerate_available() else nullcontext
     with ctx():
-        controlnet = ControlNetModel(**config)
+        controlnet = ControlNetModel(**diffusers_config)
 
     if is_accelerate_available():
         for param_name, param in diffusers_format_controlnet_checkpoint.items():
@@ -775,7 +814,6 @@ def create_controlnet_model(
 
     return {"controlnet": controlnet}
 
-
 def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
     for ldm_key in keys:
         diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")