From 2686fddbf1e358dd017ccd43cfc587f9c481463d Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Fri, 22 Dec 2023 13:54:32 +0000
Subject: [PATCH] update

---
 src/diffusers/loaders/single_file.py       |  96 +++++---
 src/diffusers/loaders/single_file_utils.py | 260 +++++++++++++++++++--
 2 files changed, 302 insertions(+), 54 deletions(-)

diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index 327e3ba29d..4fb539f853 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -40,6 +40,69 @@ if is_accelerate_available():
 logger = logging.get_logger(__name__)
 
+DIFFUSER_PIPELINE_CONFIGS = {
+    "StableDiffusionPipeline": None,
+    "StableDiffusionImg2ImgPipeline": None,
+    "StableDiffusionInpaintPipeline": None,
+    "StableDiffusionControlNetPipeline": None,
+}
+
+VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
+MODEL_TYPE_FROM_PIPELINE_CLASS = {
+    "StableUnCLIPPipeline": "FrozenOpenCLIPEmbedder",
+    "StableUnCLIPImg2ImgPipeline": "FrozenOpenCLIPEmbedder",
+}
+
+
+def check_valid_url(pretrained_model_link_or_path):
+    # strip any huggingface url prefix and report whether one was found
+    has_valid_url_prefix = False
+    for prefix in VALID_URL_PREFIXES:
+        if pretrained_model_link_or_path.startswith(prefix):
+            pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+            has_valid_url_prefix = True
+
+    return has_valid_url_prefix, pretrained_model_link_or_path
+
+
+def fetch_model_checkpoint(
+    ckpt_path,
+    cache_dir=None,
+    resume_download=False,
+    force_download=False,
+    proxies=None,
+    local_files_only=None,
+    token=None,
+    revision=None,
+):
+    # get repo_id and (potentially nested) file path of ckpt in repo
+    repo_id = "/".join(ckpt_path.parts[:2])
+    file_path = "/".join(ckpt_path.parts[2:])
+
+    if file_path.startswith("blob/"):
+        file_path = file_path[len("blob/") :]
+
+    if file_path.startswith("main/"):
+        file_path = file_path[len("main/") :]
+
+    path = hf_hub_download(
+        repo_id,
+        filename=file_path,
+        cache_dir=cache_dir,
+        resume_download=resume_download,
+        proxies=proxies,
+        local_files_only=local_files_only,
+        token=token,
+        revision=revision,
+        force_download=force_download,
+    )
+
+    return path
+
+
+def infer_model_type(pipeline_class_name):
+    return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None)
+
 
 class FromSingleFileMixin:
     """
     Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
@@ -150,12 +213,11 @@ class FromSingleFileMixin:
         >>> pipeline.to("cuda")
         ```
         """
-        original_config_file = kwargs.pop("original_config_file", None)
         config_files = kwargs.pop("config_files", None)
         cache_dir = kwargs.pop("cache_dir", None)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
@@ -221,43 +283,26 @@ class FromSingleFileMixin:
         else:
             raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
 
-        # remove huggingface url
-        has_valid_url_prefix = False
-        valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
-        for prefix in valid_url_prefixes:
-            if pretrained_model_link_or_path.startswith(prefix):
-                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
-                has_valid_url_prefix = True
+        has_valid_url_prefix, pretrained_model_link_or_path = check_valid_url(pretrained_model_link_or_path)
 
         # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
         ckpt_path = Path(pretrained_model_link_or_path)
         if not ckpt_path.is_file():
             if not has_valid_url_prefix:
                 raise ValueError(
-                    f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
+                    f"The provided path is not a file, and a valid huggingface URL was not provided. Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}"
                 )
 
-            # get repo_id and (potentially nested) file path of ckpt in repo
-            repo_id = "/".join(ckpt_path.parts[:2])
-            file_path = "/".join(ckpt_path.parts[2:])
-
-            if file_path.startswith("blob/"):
-                file_path = file_path[len("blob/") :]
-
-            if file_path.startswith("main/"):
-                file_path = file_path[len("main/") :]
-
-            pretrained_model_link_or_path = hf_hub_download(
-                repo_id,
-                filename=file_path,
+            pretrained_model_link_or_path = fetch_model_checkpoint(
+                ckpt_path,
                 cache_dir=cache_dir,
                 resume_download=resume_download,
+                force_download=force_download,
                 proxies=proxies,
                 local_files_only=local_files_only,
                 token=token,
                 revision=revision,
-                force_download=force_download,
             )
 
         pipe = download_from_original_stable_diffusion_ckpt(
             pretrained_model_link_or_path,
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 3edcaca28b..0effa4d826 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
""" Conversion script for the Stable Diffusion checkpoints.""" -import re from contextlib import nullcontext from io import BytesIO from typing import Dict, Optional, Union @@ -26,18 +25,12 @@ from safetensors.torch import load_file as safe_load from transformers import ( AutoFeatureExtractor, BertTokenizerFast, - CLIPImageProcessor, - CLIPTextConfig, - CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, ) from ...models import ( AutoencoderKL, - ControlNetModel, PriorTransformer, UNet2DConditionModel, ) @@ -54,11 +47,8 @@ from ...schedulers import ( ) from ...utils import is_accelerate_available, is_omegaconf_available, logging from ...utils.import_utils import BACKENDS_MAPPING -from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from ..paint_by_example import PaintByExampleImageEncoder from ..pipeline_utils import DiffusionPipeline from .safety_checker import StableDiffusionSafetyChecker -from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer if is_accelerate_available(): @@ -147,7 +137,7 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True) return checkpoint -def get_model_type(original_config, model_type=None): +def set_model_type(original_config, model_type=None): if model_type is not None: return model_type @@ -710,6 +700,242 @@ def convert_ldm_unet_checkpoint( return new_checkpoint + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only 
+    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+    up_blocks = {
+        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+    }
+
+    for i in range(num_down_blocks):
+        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+                f"encoder.down.{i}.downsample.conv.weight"
+            )
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+                f"encoder.down.{i}.downsample.conv.bias"
+            )
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+    conv_attn_to_linear(new_checkpoint)
+
+    for i in range(num_up_blocks):
+        block_id = num_up_blocks - 1 - i
+        resnets = [
+            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+        ]
+
+        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+                f"decoder.up.{block_id}.upsample.conv.weight"
+            ]
+            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+                f"decoder.up.{block_id}.upsample.conv.bias"
+            ]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+    num_mid_res_blocks = 2
+    for i in range(1, num_mid_res_blocks + 1):
+        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+        paths = renew_vae_resnet_paths(resnets)
+        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+    paths = renew_vae_attention_paths(mid_attentions)
+    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+    conv_attn_to_linear(new_checkpoint)
+    return new_checkpoint
+
+
+def convert_ldm_bert_checkpoint(checkpoint, config):
+    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
+        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
+        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
+        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
+
+        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
+        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
+
+    def _copy_linear(hf_linear, pt_linear):
+        hf_linear.weight = pt_linear.weight
+        hf_linear.bias = pt_linear.bias
+
+    def _copy_layer(hf_layer, pt_layer):
+        # copy layer norms
+        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
+        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
+
+        # copy attn
+        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
+
+        # copy MLP
+        pt_mlp = pt_layer[1][1]
+        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
+        _copy_linear(hf_layer.fc2, pt_mlp.net[2])
+
+    def _copy_layers(hf_layers, pt_layers):
+        for i, hf_layer in enumerate(hf_layers):
+            if i != 0:
+                # each HF layer maps to a pair of (attn, ff) layers in the original checkpoint, so double the index
+                i += i
+            pt_layer = pt_layers[i : i + 2]
+            _copy_layer(hf_layer, pt_layer)
+
+    hf_model = LDMBertModel(config).eval()
+
+    # copy embeds
+    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
+    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
+
+    # copy layer norm
+    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
+
+    # copy hidden layers
+    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
+
+    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
+
+    return hf_model
+
+
+def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
+    if text_encoder is None:
+        config_name = "openai/clip-vit-large-patch14"
+        try:
+            config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
+        except Exception:
+            raise ValueError(
+                f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
+            )
+
+        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        with ctx():
+            text_model = CLIPTextModel(config)
+    else:
+        text_model = text_encoder
+
+    keys = list(checkpoint.keys())
+
+    text_model_dict = {}
+
+    remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]
+
+    for key in keys:
+        for prefix in remove_prefixes:
+            if key.startswith(prefix):
+                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
+
+    if is_accelerate_available():
+        for param_name, param in text_model_dict.items():
+            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    else:
+        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+            text_model_dict.pop("text_model.embeddings.position_ids", None)
+
+        text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
+def create_unet_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs):
+    extract_ema = kwargs.get("extract_ema", False)
+
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
+    diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(
+        checkpoint, unet_config, path=path, extract_ema=extract_ema
+    )
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        unet = UNet2DConditionModel(**unet_config)
+
+    if is_accelerate_available():
+        for param_name, param in diffusers_format_unet_checkpoint.items():
+            set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+    else:
+        unet.load_state_dict(diffusers_format_unet_checkpoint)
+
+    return unet
+
+
+def create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs):
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+    diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        vae = AutoencoderKL(**vae_config)
+
+    if is_accelerate_available():
+        for param_name, param in diffusers_format_vae_checkpoint.items():
+            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+    else:
+        vae.load_state_dict(diffusers_format_vae_checkpoint)
+
+    return vae
+
+
 def download_from_original_stable_diffusion_ckpt(
     checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
     original_config_file: str = None,
@@ -737,6 +964,7 @@ def download_from_original_stable_diffusion_ckpt(
     tokenizer=None,
     tokenizer_2=None,
     config_files=None,
+    **kwargs,
 ) -> DiffusionPipeline:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -837,15 +1065,12 @@ def download_from_original_stable_diffusion_ckpt(
         checkpoint = checkpoint["state_dict"]
 
     original_config = fetch_original_config(checkpoint, config_files)
-    model_type = get_model_type(original_config, model_type)
+    model_type = set_model_type(original_config, model_type)
 
-    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
-    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
-    diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, unet_config, path=path, extract_ema=extract_ema
-    )
-
-    num_channels = get_num_channels()
+    unet = create_unet_model(
+        original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, extract_ema=extract_ema, **kwargs
+    )
+    vae = create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs)
 
     if pipeline_class is None:
         # Check if we have a SDXL or SD model and initialize default pipeline
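
A minimal usage sketch of the two helpers this patch factors out (not part of the patch itself). It assumes `check_valid_url` and `fetch_model_checkpoint` remain importable module-level functions in `diffusers.loaders.single_file`, and the repo/file in the URL are only illustrative:

```python
from pathlib import Path

from diffusers.loaders.single_file import check_valid_url, fetch_model_checkpoint

# A hub URL for a single-file checkpoint; any prefix in VALID_URL_PREFIXES works.
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"

# check_valid_url strips the hub prefix and reports whether one was found.
has_valid_url_prefix, stripped = check_valid_url(url)
assert has_valid_url_prefix

# fetch_model_checkpoint splits the remainder into repo_id ("runwayml/stable-diffusion-v1-5")
# and a nested file path (dropping "blob/" and "main/"), then resolves it to a local
# cache path via hf_hub_download (network access required).
ckpt_path = Path(stripped)
if not ckpt_path.is_file():
    print(fetch_model_checkpoint(ckpt_path))
```

The public entry point is unchanged: `StableDiffusionPipeline.from_single_file(url)` runs exactly this prefix-check/fetch path before handing the local file to `download_from_original_stable_diffusion_ckpt`.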