From fee93c81eb7c5e9fe1618f858f1e369567170edc Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 23 Jan 2024 14:42:03 +0530 Subject: [PATCH] [Refactor] Update from single file (#6428) * update * update * update * update * update * update * update * update * update * update * update' * update * update * update * update * update * update * up * update * update * update * update * update * update * update * update * update * update * update * update * up * update * update * update * update * update' * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * clean * update * update * clean up * clean up * update * clean * clean * update * updaet * clean up * fix docs * update * update * Revert "update" This reverts commit dbfb8f1ea9c61a2b4e02f926245be2b3d387e577. * update * update * update * update * fix controlnet * fix scheduler * fix controlnet tests --- docs/source/en/api/loaders/single_file.md | 4 +- src/diffusers/loaders/__init__.py | 10 +- src/diffusers/loaders/autoencoder.py | 126 ++ src/diffusers/loaders/controlnet.py | 127 ++ src/diffusers/loaders/single_file.py | 696 +++------ src/diffusers/loaders/single_file_utils.py | 1387 +++++++++++++++++ .../autoencoder_kl_temporal_decoder.py | 3 +- src/diffusers/models/controlnet.py | 4 +- src/diffusers/models/modeling_utils.py | 8 +- src/diffusers/pipelines/pipeline_utils.py | 9 +- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/constants.py | 1 + src/diffusers/utils/hub_utils.py | 18 +- tests/pipelines/controlnet/test_controlnet.py | 49 +- .../controlnet/test_controlnet_img2img.py | 56 +- .../controlnet/test_controlnet_inpaint.py | 4 +- .../controlnet/test_controlnet_sdxl.py | 44 +- .../stable_diffusion/test_stable_diffusion.py | 12 +- .../test_stable_diffusion_inpaint.py | 5 +- .../test_stable_diffusion_xl.py | 33 + .../test_stable_diffusion_xl_adapter.py | 37 + .../test_stable_diffusion_xl_img2img.py | 46 + 22 files changed, 2082 insertions(+), 598 deletions(-) create mode 100644 src/diffusers/loaders/autoencoder.py create mode 100644 src/diffusers/loaders/controlnet.py create mode 100644 src/diffusers/loaders/single_file_utils.py diff --git a/docs/source/en/api/loaders/single_file.md b/docs/source/en/api/loaders/single_file.md index 52e4460645..62dbc21067 100644 --- a/docs/source/en/api/loaders/single_file.md +++ b/docs/source/en/api/loaders/single_file.md @@ -30,8 +30,8 @@ To learn more about how to load single file weights, see the [Load different Sta ## FromOriginalVAEMixin -[[autodoc]] loaders.single_file.FromOriginalVAEMixin +[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin ## FromOriginalControlnetMixin -[[autodoc]] loaders.single_file.FromOriginalControlnetMixin \ No newline at end of file +[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin \ No newline at end of file diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index d7855206a2..4da047435d 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -54,12 +54,13 @@ if is_transformers_available(): _import_structure = {} if is_torch_available(): - _import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"] + _import_structure["autoencoder"] = ["FromOriginalVAEMixin"] + + _import_structure["controlnet"] = ["FromOriginalControlNetMixin"] 
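For context on the `_import_structure` edits above and below: this table feeds diffusers' lazy-import machinery, so a submodule such as `autoencoder` or `controlnet` is only imported when one of its exported names is first accessed. A minimal sketch of the pattern (diffusers wires the table into its own `_LazyModule` helper; the class below is a simplified stand-in, not the library's implementation):

```py
# Simplified stand-in for the lazy-import pattern used with `_import_structure`.
# diffusers' real implementation is `_LazyModule`; this sketch only shows the idea.
import importlib


class LazyModule:
    def __init__(self, package_name, import_structure):
        self._package_name = package_name
        # Invert the table: exported attribute -> submodule that defines it.
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        module_name = self._attr_to_module.get(attr)
        if module_name is None:
            raise AttributeError(attr)
        # The submodule is imported only on first attribute access.
        module = importlib.import_module(f".{module_name}", self._package_name)
        return getattr(module, attr)
```

Splitting `FromOriginalVAEMixin` and `FromOriginalControlNetMixin` into separate submodules means importing one no longer drags in the other's dependencies.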
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] - if is_transformers_available(): - _import_structure["single_file"].extend(["FromSingleFileMixin"]) + _import_structure["single_file"] = ["FromSingleFileMixin"] _import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = ["IPAdapterMixin"] @@ -69,7 +70,8 @@ _import_structure["peft"] = ["PeftAdapterMixin"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): - from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin + from .autoencoder import FromOriginalVAEMixin + from .controlnet import FromOriginalControlNetMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers diff --git a/src/diffusers/loaders/autoencoder.py b/src/diffusers/loaders/autoencoder.py new file mode 100644 index 0000000000..6e65bd1c00 --- /dev/null +++ b/src/diffusers/loaders/autoencoder.py @@ -0,0 +1,126 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub.utils import validate_hf_hub_args + +from .single_file_utils import ( + create_diffusers_vae_model_from_ldm, + fetch_ldm_config_and_checkpoint, +) + + +class FromOriginalVAEMixin: + """ + Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`]. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or + `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. 
+ proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite loadable and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipeline's `__init__` + method. See example below for more information. + + <Tip warning={true}> + + Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading + a VAE from SDXL or a Stable Diffusion v2 model or higher. + + </Tip> + + Examples: + + ```py + from diffusers import AutoencoderKL + + url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be a local file + model = AutoencoderKL.from_single_file(url) + ``` + """ + + original_config_file = kwargs.pop("original_config_file", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + use_safetensors = kwargs.pop("use_safetensors", True) + + class_name = cls.__name__ + original_config, checkpoint = fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path=pretrained_model_link_or_path, + class_name=class_name, + original_config_file=original_config_file, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + use_safetensors=use_safetensors, + cache_dir=cache_dir, + ) + + image_size = kwargs.pop("image_size", None) + component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size) + vae = component["vae"] + if torch_dtype is not None: + vae = vae.to(torch_dtype) + + return vae diff --git a/src/diffusers/loaders/controlnet.py b/src/diffusers/loaders/controlnet.py new file mode 100644 index 0000000000..527a77109a --- /dev/null +++ b/src/diffusers/loaders/controlnet.py @@ -0,0 +1,127 @@ +# Copyright 2023 The HuggingFace Team.
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub.utils import validate_hf_hub_args + +from .single_file_utils import ( + create_diffusers_controlnet_model_from_ldm, + fetch_ldm_config_and_checkpoint, +) + + +class FromOriginalControlNetMixin: + """ + Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`]. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or + `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on.
Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite loadable and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipeline's `__init__` + method. See example below for more information. + + Examples: + + ```py + from diffusers import StableDiffusionControlNetPipeline, ControlNetModel + + url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path + controlnet = ControlNetModel.from_single_file(url) + + url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path + pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet) + ``` + """ + original_config_file = kwargs.pop("original_config_file", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + use_safetensors = kwargs.pop("use_safetensors", True) + + class_name = cls.__name__ + original_config, checkpoint = fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path=pretrained_model_link_or_path, + class_name=class_name, + original_config_file=original_config_file, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + use_safetensors=use_safetensors, + cache_dir=cache_dir, + ) + + upcast_attention = kwargs.pop("upcast_attention", False) + image_size = kwargs.pop("image_size", None) + + component = create_diffusers_controlnet_model_from_ldm( + class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size + ) + controlnet = component["controlnet"] + if torch_dtype is not None: + controlnet = controlnet.to(torch_dtype) + + return controlnet diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 4086b1a2a8..034271aaba 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -11,39 +11,132 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
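Both new mixins above, and the refactored `FromSingleFileMixin` in the hunk that follows, share the same two-step flow: `fetch_ldm_config_and_checkpoint` resolves the link or path into a state dict paired with its original LDM YAML config, and a `create_diffusers_*_from_ldm` helper converts that pair into diffusers components. A condensed sketch of the flow, using the private helpers this patch adds (internal API, subject to change; `upcast_attention` and `image_size` are assumed to have usable defaults when omitted):

```py
# Condensed sketch of what ControlNetModel.from_single_file does internally,
# using the private helpers added by this patch (not a public API).
from diffusers.loaders.single_file_utils import (
    create_diffusers_controlnet_model_from_ldm,
    fetch_ldm_config_and_checkpoint,
)

ckpt = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"

# Step 1: resolve the link/path to a state dict plus its original LDM config.
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
    pretrained_model_link_or_path=ckpt,
    class_name="ControlNetModel",
)

# Step 2: convert the LDM checkpoint into a diffusers component.
component = create_diffusers_controlnet_model_from_ldm(
    "ControlNetModel", original_config, checkpoint
)
controlnet = component["controlnet"]
```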
-from contextlib import nullcontext -from io import BytesIO -from pathlib import Path -import requests -import torch -import yaml -from huggingface_hub import hf_hub_download from huggingface_hub.utils import validate_hf_hub_args -from ..utils import deprecate, is_accelerate_available, is_transformers_available, logging +from ..utils import is_transformers_available, logging +from .single_file_utils import ( + create_diffusers_unet_model_from_ldm, + create_diffusers_vae_model_from_ldm, + create_scheduler_from_ldm, + create_text_encoders_and_tokenizers_from_ldm, + fetch_ldm_config_and_checkpoint, + infer_model_type, +) -if is_transformers_available(): - pass - -if is_accelerate_available(): - from accelerate import init_empty_weights - logger = logging.get_logger(__name__) +# Pipelines that support the SDXL Refiner checkpoint +REFINER_PIPELINES = [ + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", +] + +if is_transformers_available(): + from transformers import AutoFeatureExtractor + + +def build_sub_model_components( + pipeline_components, + pipeline_class_name, + component_name, + original_config, + checkpoint, + local_files_only=False, + load_safety_checker=False, + model_type=None, + image_size=None, + **kwargs, +): + if component_name in pipeline_components: + return {} + + if component_name == "unet": + num_in_channels = kwargs.pop("num_in_channels", None) + unet_components = create_diffusers_unet_model_from_ldm( + pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size + ) + return unet_components + + if component_name == "vae": + vae_components = create_diffusers_vae_model_from_ldm( + pipeline_class_name, original_config, checkpoint, image_size + ) + return vae_components + + if component_name == "scheduler": + scheduler_type = kwargs.get("scheduler_type", "ddim") + prediction_type = kwargs.get("prediction_type", None) + + scheduler_components = create_scheduler_from_ldm( + pipeline_class_name, + original_config, + checkpoint, + scheduler_type=scheduler_type, + prediction_type=prediction_type, + model_type=model_type, + ) + + return scheduler_components + + if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]: + text_encoder_components = create_text_encoders_and_tokenizers_from_ldm( + original_config, + checkpoint, + model_type=model_type, + local_files_only=local_files_only, + ) + return text_encoder_components + + if component_name == "safety_checker": + if load_safety_checker: + from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + else: + safety_checker = None + return {"safety_checker": safety_checker} + + if component_name == "feature_extractor": + if load_safety_checker: + feature_extractor = AutoFeatureExtractor.from_pretrained( + "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only + ) + else: + feature_extractor = None + return {"feature_extractor": feature_extractor} + + return + + +def set_additional_components( + pipeline_class_name, + original_config, + model_type=None, +): + components = {} + if pipeline_class_name in REFINER_PIPELINES: + model_type = infer_model_type(original_config, model_type=model_type) + is_refiner = model_type == "SDXL-Refiner" + components.update( + { + "requires_aesthetics_score": 
is_refiner, + "force_zeros_for_empty_prompt": False if is_refiner else True, + } + ) + + return components + class FromSingleFileMixin: """ Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`]. """ - @classmethod - def from_ckpt(cls, *args, **kwargs): - deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead." - deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False) - return cls.from_single_file(*args, **kwargs) - @classmethod @validate_hf_hub_args def from_single_file(cls, pretrained_model_link_or_path, **kwargs): @@ -58,8 +151,7 @@ class FromSingleFileMixin: `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - A path to a *file* containing all pipeline weights. torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. + Override the default `torch.dtype` and load the model with another dtype. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -85,42 +177,6 @@ class FromSingleFileMixin: If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. - extract_ema (`bool`, *optional*, defaults to `False`): - Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield - higher quality images for inference. Non-EMA weights are usually better for continuing finetuning. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. - image_size (`int`, *optional*, defaults to 512): - The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable - Diffusion v2 base model. Use 768 for Stable Diffusion v2. - prediction_type (`str`, *optional*): - The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and - the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2. - num_in_channels (`int`, *optional*, defaults to `None`): - The number of input channels. If `None`, it is automatically inferred. - scheduler_type (`str`, *optional*, defaults to `"pndm"`): - Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", - "ddim"]`. - load_safety_checker (`bool`, *optional*, defaults to `True`): - Whether to load the safety checker or not. - text_encoder ([`~transformers.CLIPTextModel`], *optional*, defaults to `None`): - An instance of `CLIPTextModel` to use, specifically the - [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this - parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed. - vae (`AutoencoderKL`, *optional*, defaults to `None`): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If - this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed. 
- tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`): - An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance - of `CLIPTokenizer` by itself if needed. - original_config_file (`str`): - Path to `.yaml` config file corresponding to the original architecture. If `None`, will be - automatically inferred by looking for a key that only exists in SD2.0 models. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables (for example the pipeline components of the - specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` - method. See example below for more information. - Examples: ```py @@ -143,484 +199,80 @@ class FromSingleFileMixin: >>> pipeline.to("cuda") ``` """ - # import here to avoid circular dependency - from ..pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt - original_config_file = kwargs.pop("original_config_file", None) - config_files = kwargs.pop("config_files", None) - cache_dir = kwargs.pop("cache_dir", None) resume_download = kwargs.pop("resume_download", False) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) - extract_ema = kwargs.pop("extract_ema", False) - image_size = kwargs.pop("image_size", None) - scheduler_type = kwargs.pop("scheduler_type", "pndm") - num_in_channels = kwargs.pop("num_in_channels", None) - upcast_attention = kwargs.pop("upcast_attention", None) - load_safety_checker = kwargs.pop("load_safety_checker", True) - prediction_type = kwargs.pop("prediction_type", None) - text_encoder = kwargs.pop("text_encoder", None) - text_encoder_2 = kwargs.pop("text_encoder_2", None) - vae = kwargs.pop("vae", None) - controlnet = kwargs.pop("controlnet", None) - adapter = kwargs.pop("adapter", None) - tokenizer = kwargs.pop("tokenizer", None) - tokenizer_2 = kwargs.pop("tokenizer_2", None) - torch_dtype = kwargs.pop("torch_dtype", None) + use_safetensors = kwargs.pop("use_safetensors", True) - use_safetensors = kwargs.pop("use_safetensors", None) + class_name = cls.__name__ - pipeline_name = cls.__name__ - file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] - from_safetensors = file_extension == "safetensors" - - if from_safetensors and use_safetensors is False: - raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") - - # TODO: For now we only support stable diffusion - stable_unclip = None - model_type = None - - if pipeline_name in [ - "StableDiffusionControlNetPipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - ]: - from ..models.controlnet import ControlNetModel - from ..pipelines.controlnet.multicontrolnet import MultiControlNetModel - - # list/tuple or a single instance of ControlNetModel or MultiControlNetModel - if not ( - isinstance(controlnet, (ControlNetModel, MultiControlNetModel)) - or isinstance(controlnet, (list, tuple)) - and isinstance(controlnet[0], ControlNetModel) - ): - raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.") - elif "StableDiffusion" in pipeline_name: - # Model type will be inferred from the checkpoint. 
- pass - elif pipeline_name == "StableUnCLIPPipeline": - model_type = "FrozenOpenCLIPEmbedder" - stable_unclip = "txt2img" - elif pipeline_name == "StableUnCLIPImg2ImgPipeline": - model_type = "FrozenOpenCLIPEmbedder" - stable_unclip = "img2img" - elif pipeline_name == "PaintByExamplePipeline": - model_type = "PaintByExample" - elif pipeline_name == "LDMTextToImagePipeline": - model_type = "LDMTextToImage" - else: - raise ValueError(f"Unhandled pipeline class: {pipeline_name}") - - # remove huggingface url - has_valid_url_prefix = False - valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] - for prefix in valid_url_prefixes: - if pretrained_model_link_or_path.startswith(prefix): - pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] - has_valid_url_prefix = True - - # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained - ckpt_path = Path(pretrained_model_link_or_path) - if not ckpt_path.is_file(): - if not has_valid_url_prefix: - raise ValueError( - f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}" - ) - - # get repo_id and (potentially nested) file path of ckpt in repo - repo_id = "/".join(ckpt_path.parts[:2]) - file_path = "/".join(ckpt_path.parts[2:]) - - if file_path.startswith("blob/"): - file_path = file_path[len("blob/") :] - - if file_path.startswith("main/"): - file_path = file_path[len("main/") :] - - pretrained_model_link_or_path = hf_hub_download( - repo_id, - filename=file_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - force_download=force_download, - ) - - pipe = download_from_original_stable_diffusion_ckpt( - pretrained_model_link_or_path, - pipeline_class=cls, - model_type=model_type, - stable_unclip=stable_unclip, - controlnet=controlnet, - adapter=adapter, - from_safetensors=from_safetensors, - extract_ema=extract_ema, - image_size=image_size, - scheduler_type=scheduler_type, - num_in_channels=num_in_channels, - upcast_attention=upcast_attention, - load_safety_checker=load_safety_checker, - prediction_type=prediction_type, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - vae=vae, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, + original_config, checkpoint = fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path=pretrained_model_link_or_path, + class_name=class_name, original_config_file=original_config_file, - config_files=config_files, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, local_files_only=local_files_only, + use_safetensors=use_safetensors, + cache_dir=cache_dir, ) + from ..pipelines.pipeline_utils import _get_pipeline_class + + pipeline_class = _get_pipeline_class( + cls, + config=None, + cache_dir=cache_dir, + ) + + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + + model_type = kwargs.pop("model_type", None) + image_size = kwargs.pop("image_size", None) + load_safety_checker = (kwargs.pop("load_safety_checker", False)) or ( + passed_class_obj.get("safety_checker", None) is not None + ) + + init_kwargs = {} + for name in expected_modules: + if name in passed_class_obj: + 
init_kwargs[name] = passed_class_obj[name] + else: + components = build_sub_model_components( + init_kwargs, + class_name, + name, + original_config, + checkpoint, + model_type=model_type, + image_size=image_size, + load_safety_checker=load_safety_checker, + local_files_only=local_files_only, + **kwargs, + ) + if not components: + continue + init_kwargs.update(components) + + additional_components = set_additional_components(class_name, original_config, model_type=model_type) + if additional_components: + init_kwargs.update(additional_components) + + init_kwargs.update(passed_pipe_kwargs) + pipe = pipeline_class(**init_kwargs) + if torch_dtype is not None: pipe.to(dtype=torch_dtype) return pipe - - -class FromOriginalVAEMixin: - """ - Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into an [`AutoencoderKL`]. - """ - - @classmethod - @validate_hf_hub_args - def from_single_file(cls, pretrained_model_link_or_path, **kwargs): - r""" - Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or - `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. - - Parameters: - pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - A link to the `.ckpt` file (for example - `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - - A path to a *file* containing all pipeline weights. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to True, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - image_size (`int`, *optional*, defaults to 512): - The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable - Diffusion v2 base model. Use 768 for Stable Diffusion v2. - use_safetensors (`bool`, *optional*, defaults to `None`): - If set to `None`, the safetensors weights are downloaded if they're available **and** if the - safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors - weights. 
If set to `False`, safetensors weights are not loaded. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. - scaling_factor (`float`, *optional*, defaults to 0.18215): - The component-wise standard deviation of the trained latent space computed using the first batch of the - training set. This is used to scale the latent space to have unit variance when training the diffusion - model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the - diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z - = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution - Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables (for example the pipeline components of the - specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` - method. See example below for more information. - - - - Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading - a VAE from SDXL or a Stable Diffusion v2 model or higher. - - - - Examples: - - ```py - from diffusers import AutoencoderKL - - url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file - model = AutoencoderKL.from_single_file(url) - ``` - """ - from ..models import AutoencoderKL - - # import here to avoid circular dependency - from ..pipelines.stable_diffusion.convert_from_ckpt import ( - convert_ldm_vae_checkpoint, - create_vae_diffusers_config, - ) - - config_file = kwargs.pop("config_file", None) - cache_dir = kwargs.pop("cache_dir", None) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - image_size = kwargs.pop("image_size", None) - scaling_factor = kwargs.pop("scaling_factor", None) - kwargs.pop("upcast_attention", None) - - torch_dtype = kwargs.pop("torch_dtype", None) - - use_safetensors = kwargs.pop("use_safetensors", None) - - file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] - from_safetensors = file_extension == "safetensors" - - if from_safetensors and use_safetensors is False: - raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") - - # remove huggingface url - for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: - if pretrained_model_link_or_path.startswith(prefix): - pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] - - # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained - ckpt_path = Path(pretrained_model_link_or_path) - if not ckpt_path.is_file(): - # get repo_id and (potentially nested) file path of ckpt in repo - repo_id = "/".join(ckpt_path.parts[:2]) - file_path = "/".join(ckpt_path.parts[2:]) - - if file_path.startswith("blob/"): - file_path = file_path[len("blob/") :] - - if file_path.startswith("main/"): - file_path = file_path[len("main/") :] - - pretrained_model_link_or_path = hf_hub_download( - repo_id, - filename=file_path, - cache_dir=cache_dir, - 
resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - force_download=force_download, - ) - - if from_safetensors: - from safetensors import safe_open - - checkpoint = {} - with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f: - for key in f.keys(): - checkpoint[key] = f.get_tensor(key) - else: - checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu") - - if "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - if config_file is None: - config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - config_file = BytesIO(requests.get(config_url).content) - - original_config = yaml.safe_load(config_file) - - # default to sd-v1-5 - image_size = image_size or 512 - - vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) - - if scaling_factor is None: - if ( - "model" in original_config - and "params" in original_config["model"] - and "scale_factor" in original_config["model"]["params"] - ): - vae_scaling_factor = original_config["model"]["params"]["scale_factor"] - else: - vae_scaling_factor = 0.18215 # default SD scaling factor - - vae_config["scaling_factor"] = vae_scaling_factor - - ctx = init_empty_weights if is_accelerate_available() else nullcontext - with ctx(): - vae = AutoencoderKL(**vae_config) - - if is_accelerate_available(): - from ..models.modeling_utils import load_model_dict_into_meta - - load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu") - else: - vae.load_state_dict(converted_vae_checkpoint) - - if torch_dtype is not None: - vae.to(dtype=torch_dtype) - - return vae - - -class FromOriginalControlnetMixin: - """ - Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`]. - """ - - @classmethod - @validate_hf_hub_args - def from_single_file(cls, pretrained_model_link_or_path, **kwargs): - r""" - Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or - `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. - - Parameters: - pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): - Can be either: - - A link to the `.ckpt` file (for example - `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - - A path to a *file* containing all pipeline weights. - torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to resume downloading the model weights and configuration files. If set to `False`, any - incompletely downloaded files are deleted. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. 
The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to True, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - use_safetensors (`bool`, *optional*, defaults to `None`): - If set to `None`, the safetensors weights are downloaded if they're available **and** if the - safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors - weights. If set to `False`, safetensors weights are not loaded. - image_size (`int`, *optional*, defaults to 512): - The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable - Diffusion v2 base model. Use 768 for Stable Diffusion v2. - upcast_attention (`bool`, *optional*, defaults to `None`): - Whether the attention computation should always be upcasted. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to overwrite load and saveable variables (for example the pipeline components of the - specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` - method. See example below for more information. - - Examples: - - ```py - from diffusers import StableDiffusionControlNetPipeline, ControlNetModel - - url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path - model = ControlNetModel.from_single_file(url) - - url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path - pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet) - ``` - """ - # import here to avoid circular dependency - from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt - - config_file = kwargs.pop("config_file", None) - cache_dir = kwargs.pop("cache_dir", None) - resume_download = kwargs.pop("resume_download", False) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - num_in_channels = kwargs.pop("num_in_channels", None) - use_linear_projection = kwargs.pop("use_linear_projection", None) - revision = kwargs.pop("revision", None) - extract_ema = kwargs.pop("extract_ema", False) - image_size = kwargs.pop("image_size", None) - upcast_attention = kwargs.pop("upcast_attention", None) - - torch_dtype = kwargs.pop("torch_dtype", None) - - use_safetensors = kwargs.pop("use_safetensors", None) - - file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] - from_safetensors = file_extension == "safetensors" - - if from_safetensors and use_safetensors is False: - raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") - - # remove huggingface url - for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: - if pretrained_model_link_or_path.startswith(prefix): - pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] - - # Code based on 
diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained - ckpt_path = Path(pretrained_model_link_or_path) - if not ckpt_path.is_file(): - # get repo_id and (potentially nested) file path of ckpt in repo - repo_id = "/".join(ckpt_path.parts[:2]) - file_path = "/".join(ckpt_path.parts[2:]) - - if file_path.startswith("blob/"): - file_path = file_path[len("blob/") :] - - if file_path.startswith("main/"): - file_path = file_path[len("main/") :] - - pretrained_model_link_or_path = hf_hub_download( - repo_id, - filename=file_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - force_download=force_download, - ) - - if config_file is None: - config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml" - config_file = BytesIO(requests.get(config_url).content) - - image_size = image_size or 512 - - controlnet = download_controlnet_from_original_ckpt( - pretrained_model_link_or_path, - original_config_file=config_file, - image_size=image_size, - extract_ema=extract_ema, - num_in_channels=num_in_channels, - upcast_attention=upcast_attention, - from_safetensors=from_safetensors, - use_linear_projection=use_linear_projection, - ) - - if torch_dtype is not None: - controlnet.to(dtype=torch_dtype) - - return controlnet diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py new file mode 100644 index 0000000000..fb1ad14fd3 --- /dev/null +++ b/src/diffusers/loaders/single_file_utils.py @@ -0,0 +1,1387 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
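For orientation before the module body: the utilities below replace the old `convert_from_ckpt` entry points. One small but load-bearing piece is `_extract_repo_id_and_weights_name` (defined further down in this file), which turns a Hub URL into the `(repo_id, weights_name)` pair that `_get_model_file` needs. Given the regex and prefix list it uses, the expected behavior is roughly the following (the `org/repo` example is hypothetical):

```py
# Expected behavior of `_extract_repo_id_and_weights_name` (defined below),
# given its regex r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)" and the Hub URL
# prefixes it strips. Shown as a usage note, not part of the patch.
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name

repo_id, weights_name = _extract_repo_id_and_weights_name(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors"
)
assert repo_id == "runwayml/stable-diffusion-v1-5"
assert weights_name == "v1-5-pruned.safetensors"

# Nested paths inside a repo are preserved in `weights_name`:
repo_id, weights_name = _extract_repo_id_and_weights_name(
    "https://huggingface.co/org/repo/blob/main/checkpoints/model.ckpt"  # hypothetical repo
)
assert repo_id == "org/repo"
assert weights_name == "checkpoints/model.ckpt"
```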
+""" Conversion script for the Stable Diffusion checkpoints.""" + +import os +import re +from contextlib import nullcontext +from io import BytesIO +from urllib.parse import urlparse + +import requests +import yaml + +from ..models.modeling_utils import load_state_dict +from ..schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ..utils import is_accelerate_available, is_transformers_available, logging +from ..utils.hub_utils import _get_model_file + + +if is_transformers_available(): + from transformers import ( + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + ) + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CONFIG_URLS = { + "v1": "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml", + "v2": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml", + "xl": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml", + "xl_refiner": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml", + "upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml", + "controlnet": "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml", +} + +CHECKPOINT_KEY_NAMES = { + "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias", + "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias", +} + +SCHEDULER_DEFAULT_CONFIG = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", +} + +DIFFUSERS_TO_LDM_MAPPING = { + "unet": { + "layers": { + "time_embedding.linear_1.weight": "time_embed.0.weight", + "time_embedding.linear_1.bias": "time_embed.0.bias", + "time_embedding.linear_2.weight": "time_embed.2.weight", + "time_embedding.linear_2.bias": "time_embed.2.bias", + "conv_in.weight": "input_blocks.0.0.weight", + "conv_in.bias": "input_blocks.0.0.bias", + "conv_norm_out.weight": "out.0.weight", + "conv_norm_out.bias": "out.0.bias", + "conv_out.weight": "out.2.weight", + "conv_out.bias": "out.2.bias", + }, + "class_embed_type": { + "class_embedding.linear_1.weight": "label_emb.0.0.weight", + "class_embedding.linear_1.bias": "label_emb.0.0.bias", + "class_embedding.linear_2.weight": "label_emb.0.2.weight", + "class_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + "addition_embed_type": { + "add_embedding.linear_1.weight": "label_emb.0.0.weight", + "add_embedding.linear_1.bias": "label_emb.0.0.bias", + "add_embedding.linear_2.weight": "label_emb.0.2.weight", + "add_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + }, + "controlnet": { + "layers": { + "time_embedding.linear_1.weight": "time_embed.0.weight", + "time_embedding.linear_1.bias": 
"time_embed.0.bias", + "time_embedding.linear_2.weight": "time_embed.2.weight", + "time_embedding.linear_2.bias": "time_embed.2.bias", + "conv_in.weight": "input_blocks.0.0.weight", + "conv_in.bias": "input_blocks.0.0.bias", + "controlnet_cond_embedding.conv_in.weight": "input_hint_block.0.weight", + "controlnet_cond_embedding.conv_in.bias": "input_hint_block.0.bias", + "controlnet_cond_embedding.conv_out.weight": "input_hint_block.14.weight", + "controlnet_cond_embedding.conv_out.bias": "input_hint_block.14.bias", + }, + "class_embed_type": { + "class_embedding.linear_1.weight": "label_emb.0.0.weight", + "class_embedding.linear_1.bias": "label_emb.0.0.bias", + "class_embedding.linear_2.weight": "label_emb.0.2.weight", + "class_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + "addition_embed_type": { + "add_embedding.linear_1.weight": "label_emb.0.0.weight", + "add_embedding.linear_1.bias": "label_emb.0.0.bias", + "add_embedding.linear_2.weight": "label_emb.0.2.weight", + "add_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + }, + "vae": { + "encoder.conv_in.weight": "encoder.conv_in.weight", + "encoder.conv_in.bias": "encoder.conv_in.bias", + "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.conv_out.bias": "encoder.conv_out.bias", + "encoder.conv_norm_out.weight": "encoder.norm_out.weight", + "encoder.conv_norm_out.bias": "encoder.norm_out.bias", + "decoder.conv_in.weight": "decoder.conv_in.weight", + "decoder.conv_in.bias": "decoder.conv_in.bias", + "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.conv_out.bias": "decoder.conv_out.bias", + "decoder.conv_norm_out.weight": "decoder.norm_out.weight", + "decoder.conv_norm_out.bias": "decoder.norm_out.bias", + "quant_conv.weight": "quant_conv.weight", + "quant_conv.bias": "quant_conv.bias", + "post_quant_conv.weight": "post_quant_conv.weight", + "post_quant_conv.bias": "post_quant_conv.bias", + }, + "openclip": { + "layers": { + "text_model.embeddings.position_embedding.weight": "positional_embedding", + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.final_layer_norm.weight": "ln_final.weight", + "text_model.final_layer_norm.bias": "ln_final.bias", + "text_projection.weight": "text_projection", + }, + "transformer": { + "text_model.encoder.layers.": "resblocks.", + "layer_norm1": "ln_1", + "layer_norm2": "ln_2", + ".fc1.": ".c_fc.", + ".fc2.": ".c_proj.", + ".self_attn": ".attn", + "transformer.text_model.final_layer_norm.": "ln_final.", + "transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "transformer.text_model.embeddings.position_embedding.weight": "positional_embedding", + }, + }, +} + +LDM_VAE_KEY = "first_stage_model." +LDM_UNET_KEY = "model.diffusion_model." +LDM_CONTROLNET_KEY = "control_model." 
+LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."] +LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024 + +SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [ + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias", + "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight", + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.bias", + "cond_stage_model.model.transformer.resblocks.23.attn.out_proj.weight", + "cond_stage_model.model.transformer.resblocks.23.ln_1.bias", + "cond_stage_model.model.transformer.resblocks.23.ln_1.weight", + "cond_stage_model.model.transformer.resblocks.23.ln_2.bias", + "cond_stage_model.model.transformer.resblocks.23.ln_2.weight", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.bias", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_fc.weight", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.bias", + "cond_stage_model.model.transformer.resblocks.23.mlp.c_proj.weight", + "cond_stage_model.model.text_projection", +] + + +VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] + + +def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): + pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)" + weights_name = None + repo_id = None + for prefix in VALID_URL_PREFIXES: + pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "") + match = re.match(pattern, pretrained_model_name_or_path) + if not match: + return repo_id, weights_name + + repo_id = f"{match.group(1)}/{match.group(2)}" + weights_name = match.group(3) + + return repo_id, weights_name + + +def fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path, + class_name, + original_config_file=None, + resume_download=False, + force_download=False, + proxies=None, + token=None, + cache_dir=None, + local_files_only=None, + revision=None, + use_safetensors=True, +): + file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] + from_safetensors = file_extension == "safetensors" + + if from_safetensors and use_safetensors is False: + raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") + + if os.path.isfile(pretrained_model_link_or_path): + checkpoint = load_state_dict(pretrained_model_link_or_path) + + else: + repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path) + checkpoint_path = _get_model_file( + repo_id, + weights_name=weights_name, + force_download=force_download, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + ) + checkpoint = load_state_dict(checkpoint_path) + + # some checkpoints contain the model state dict under a "state_dict" key + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + original_config = fetch_original_config(class_name, checkpoint, original_config_file) + + return original_config, checkpoint + + +def infer_original_config_file(class_name, checkpoint): + if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: + config_url = CONFIG_URLS["v2"] + + elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint: + config_url = CONFIG_URLS["xl"] + + elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint: + config_url = CONFIG_URLS["xl_refiner"] + + elif class_name == "StableDiffusionUpscalePipeline": + config_url = CONFIG_URLS["upscale"] + + elif class_name == "ControlNetModel":
config_url = CONFIG_URLS["controlnet"]
+
+    else:
+        config_url = CONFIG_URLS["v1"]
+
+    original_config_file = BytesIO(requests.get(config_url).content)
+
+    return original_config_file
+
+
+def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None):
+    def is_valid_url(url):
+        result = urlparse(url)
+        if result.scheme and result.netloc:
+            return True
+
+        return False
+
+    if original_config_file is None:
+        original_config_file = infer_original_config_file(pipeline_class_name, checkpoint)
+
+    elif os.path.isfile(original_config_file):
+        with open(original_config_file, "r") as fp:
+            original_config_file = fp.read()
+
+    elif is_valid_url(original_config_file):
+        original_config_file = BytesIO(requests.get(original_config_file).content)
+
+    else:
+        raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
+
+    original_config = yaml.safe_load(original_config_file)
+
+    return original_config
+
+
+def infer_model_type(original_config, model_type=None):
+    if model_type is not None:
+        return model_type
+
+    has_cond_stage_config = (
+        "cond_stage_config" in original_config["model"]["params"]
+        and original_config["model"]["params"]["cond_stage_config"] is not None
+    )
+    has_network_config = (
+        "network_config" in original_config["model"]["params"]
+        and original_config["model"]["params"]["network_config"] is not None
+    )
+
+    if has_cond_stage_config:
+        model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
+
+    elif has_network_config:
+        context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
+        if context_dim == 2048:
+            model_type = "SDXL"
+        else:
+            model_type = "SDXL-Refiner"
+    else:
+        raise ValueError("Unable to infer model type from config")
+
+    logger.debug(f"No `model_type` given, `model_type` inferred as: {model_type}")
+
+    return model_type
+
+
+def get_default_scheduler_config():
+    return SCHEDULER_DEFAULT_CONFIG
+
+
+def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=None, model_type=None):
+    if image_size:
+        return image_size
+
+    global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
+    model_type = infer_model_type(original_config, model_type)
+
+    if pipeline_class_name == "StableDiffusionUpscalePipeline":
+        image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
+        return image_size
+
+    elif model_type in ["SDXL", "SDXL-Refiner"]:
+        image_size = 1024
+        return image_size
+
+    elif (
+        "parameterization" in original_config["model"]["params"]
+        and original_config["model"]["params"]["parameterization"] == "v"
+    ):
+        # NOTE: For stable diffusion 2 base one has to pass `image_size=512`
+        # as it relies on a brittle global step parameter here
+        image_size = 512 if global_step == 875000 else 768
+        return image_size
+
+    else:
+        image_size = 512
+        return image_size
+
+
+# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
+def conv_attn_to_linear(checkpoint):
+    keys = list(checkpoint.keys())
+    attn_keys = ["query.weight", "key.weight", "value.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in attn_keys:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
+        elif "proj_attn.weight" in key:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def create_unet_diffusers_config(original_config, image_size: int):
+    """
+    Creates a diffusers UNet config based on the config
of the LDM model. + """ + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + if context_dim in [2048, 1280]: + # SDXL + addition_embed_type = "text_time" + addition_time_embed_dim = 256 + else: + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params["disable_self_attentions"] + + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] + + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] 
= tuple(up_block_types)
+
+    return config
+
+
+def create_controlnet_diffusers_config(original_config, image_size: int):
+    unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
+    diffusers_unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+
+    controlnet_config = {
+        "conditioning_channels": unet_params["hint_channels"],
+        "in_channels": diffusers_unet_config["in_channels"],
+        "down_block_types": diffusers_unet_config["down_block_types"],
+        "block_out_channels": diffusers_unet_config["block_out_channels"],
+        "layers_per_block": diffusers_unet_config["layers_per_block"],
+        "cross_attention_dim": diffusers_unet_config["cross_attention_dim"],
+        "attention_head_dim": diffusers_unet_config["attention_head_dim"],
+        "use_linear_projection": diffusers_unet_config["use_linear_projection"],
+        "class_embed_type": diffusers_unet_config["class_embed_type"],
+        "addition_embed_type": diffusers_unet_config["addition_embed_type"],
+        "addition_time_embed_dim": diffusers_unet_config["addition_time_embed_dim"],
+        "projection_class_embeddings_input_dim": diffusers_unet_config["projection_class_embeddings_input_dim"],
+        "transformer_layers_per_block": diffusers_unet_config["transformer_layers_per_block"],
+    }
+
+    return controlnet_config
+
+
+def create_vae_diffusers_config(original_config, image_size, scaling_factor=0.18215):
+    """
+    Creates a diffusers VAE config based on the config of the LDM model.
+    """
+    vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
+
+    block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
+    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+    config = {
+        "sample_size": image_size,
+        "in_channels": vae_params["in_channels"],
+        "out_channels": vae_params["out_ch"],
+        "down_block_types": tuple(down_block_types),
+        "up_block_types": tuple(up_block_types),
+        "block_out_channels": tuple(block_out_channels),
+        "latent_channels": vae_params["z_channels"],
+        "layers_per_block": vae_params["num_res_blocks"],
+        "scaling_factor": scaling_factor,
+    }
+
+    return config
+
+
+def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping=None):
+    for ldm_key in ldm_keys:
+        diffusers_key = (
+            ldm_key.replace("in_layers.0", "norm1")
+            .replace("in_layers.2", "conv1")
+            .replace("out_layers.0", "norm2")
+            .replace("out_layers.3", "conv2")
+            .replace("emb_layers.1", "time_emb_proj")
+            .replace("skip_connection", "conv_shortcut")
+        )
+        if mapping:
+            diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"])
+        new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
+
+
+def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
+    for ldm_key in ldm_keys:
+        diffusers_key = ldm_key.replace(mapping["old"], mapping["new"])
+        new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
+
+
+def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
+    """
+    Takes a state dict and a config, and returns a converted checkpoint.
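+
+    A minimal usage sketch (variable names illustrative):
+
+        unet_config = create_unet_diffusers_config(original_config, image_size=512)
+        diffusers_state_dict = convert_ldm_unet_checkpoint(checkpoint, unet_config)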
+    """
+    # extract state_dict for UNet
+    unet_state_dict = {}
+    keys = list(checkpoint.keys())
+    unet_key = LDM_UNET_KEY
+
+    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+        logger.warning("Checkpoint has both EMA and non-EMA weights.")
+        logger.warning(
+            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+            " weights (useful to continue fine-tuning), please pass `extract_ema=False`."
+        )
+        for key in keys:
+            if key.startswith("model.diffusion_model"):
+                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+    else:
+        if sum(k.startswith("model_ema") for k in keys) > 100:
+            logger.warning(
+                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                " weights (usually better for inference), please pass `extract_ema=True`."
+            )
+        for key in keys:
+            if key.startswith(unet_key):
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+    new_checkpoint = {}
+    ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
+    for diffusers_key, ldm_key in ldm_unet_keys.items():
+        if ldm_key not in unet_state_dict:
+            continue
+        new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
+
+    if ("class_embed_type" in config) and (config["class_embed_type"] in ["timestep", "projection"]):
+        class_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["class_embed_type"]
+        for diffusers_key, ldm_key in class_embed_keys.items():
+            new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
+
+    if ("addition_embed_type" in config) and (config["addition_embed_type"] == "text_time"):
+        addition_embed_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["addition_embed_type"]
+        for diffusers_key, ldm_key in addition_embed_keys.items():
+            new_checkpoint[diffusers_key] = unet_state_dict[ldm_key]
+
+    # Relevant to StableDiffusionUpscalePipeline
+    if "num_class_embeds" in config:
+        if (config["num_class_embeds"] is not None) and ("label_emb.weight" in unet_state_dict):
+            new_checkpoint["class_embedding.weight"] = unet_state_dict["label_emb.weight"]
+
+    # Retrieves the keys for the input blocks only
+    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+    input_blocks = {
+        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
+        for layer_id in range(num_input_blocks)
+    }
+
+    # Retrieves the keys for the middle blocks only
+    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+    middle_blocks = {
+        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
+        for layer_id in range(num_middle_blocks)
+    }
+
+    # Retrieves the keys for the output blocks only
+    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+    output_blocks = {
+        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
+        for layer_id in range(num_output_blocks)
+    }
+
+    # Down blocks
+    for i in range(1, num_input_blocks):
+        block_id = (i - 1) // (config["layers_per_block"] + 1)
+        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+        resnets = [
+            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+        ]
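+        # The helper below renames each LDM resnet key into the diffusers layout,
+        # e.g. (illustrative, with layers_per_block=2):
+        #     "input_blocks.1.0.in_layers.0.weight" -> "down_blocks.0.resnets.0.norm1.weight"
+        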
update_unet_resnet_ldm_to_diffusers(
+            resnets,
+            new_checkpoint,
+            unet_state_dict,
+            {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
+        )
+
+        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.weight"
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.bias"
+            )
+
+        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+        if attentions:
+            update_unet_attention_ldm_to_diffusers(
+                attentions,
+                new_checkpoint,
+                unet_state_dict,
+                {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
+            )
+
+    # Mid blocks
+    resnet_0 = middle_blocks[0]
+    attentions = middle_blocks[1]
+    resnet_1 = middle_blocks[2]
+
+    update_unet_resnet_ldm_to_diffusers(
+        resnet_0, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"}
+    )
+    update_unet_resnet_ldm_to_diffusers(
+        resnet_1, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"}
+    )
+    update_unet_attention_ldm_to_diffusers(
+        attentions, new_checkpoint, unet_state_dict, mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"}
+    )
+
+    # Up blocks
+    for i in range(num_output_blocks):
+        block_id = i // (config["layers_per_block"] + 1)
+        layer_in_block_id = i % (config["layers_per_block"] + 1)
+
+        resnets = [
+            key for key in output_blocks[i] if f"output_blocks.{i}.0" in key and f"output_blocks.{i}.0.op" not in key
+        ]
+        update_unet_resnet_ldm_to_diffusers(
+            resnets,
+            new_checkpoint,
+            unet_state_dict,
+            {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"},
+        )
+
+        attentions = [
+            key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and f"output_blocks.{i}.1.conv" not in key
+        ]
+        if attentions:
+            update_unet_attention_ldm_to_diffusers(
+                attentions,
+                new_checkpoint,
+                unet_state_dict,
+                {"old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}"},
+            )
+
+        if f"output_blocks.{i}.1.conv.weight" in unet_state_dict:
+            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+                f"output_blocks.{i}.1.conv.weight"
+            ]
+            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+                f"output_blocks.{i}.1.conv.bias"
+            ]
+        if f"output_blocks.{i}.2.conv.weight" in unet_state_dict:
+            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+                f"output_blocks.{i}.2.conv.weight"
+            ]
+            new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+                f"output_blocks.{i}.2.conv.bias"
+            ]
+
+    return new_checkpoint
+
+
+def convert_controlnet_checkpoint(
+    checkpoint,
+    config,
+):
+    # Some controlnet ckpt files are distributed independently from the rest of the
+    # model components e.g.
https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + controlnet_state_dict = checkpoint + + else: + controlnet_state_dict = {} + keys = list(checkpoint.keys()) + controlnet_key = LDM_CONTROLNET_KEY + for key in keys: + if key.startswith(controlnet_key): + controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"] + for diffusers_key, ldm_key in ldm_controlnet_keys.items(): + if ldm_key not in controlnet_state_dict: + continue + new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key] + + # Retrieves the keys for the input blocks only + num_input_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer} + ) + input_blocks = { + layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Down blocks + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + update_unet_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}, + ) + + if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + if attentions: + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + controlnet_state_dict, + {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}, + ) + + # controlnet down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len( + {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer} + ) + middle_blocks = { + layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + if middle_blocks: + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + update_unet_resnet_ldm_to_diffusers( + resnet_0, + new_checkpoint, + controlnet_state_dict, + mapping={"old": "middle_block.0", "new": "mid_block.resnets.0"}, + ) + update_unet_resnet_ldm_to_diffusers( + resnet_1, + new_checkpoint, + controlnet_state_dict, + mapping={"old": "middle_block.2", "new": "mid_block.resnets.1"}, + ) + update_unet_attention_ldm_to_diffusers( + attentions, + new_checkpoint, + controlnet_state_dict, + mapping={"old": "middle_block.1", "new": "mid_block.attentions.0"}, + ) + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = 
controlnet_state_dict.pop("middle_block_out.0.bias") + + # controlnet cond embedding blocks + cond_embedding_blocks = { + ".".join(layer.split(".")[:2]) + for layer in controlnet_state_dict + if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer) + } + num_cond_embedding_blocks = len(cond_embedding_blocks) + + for idx in range(1, num_cond_embedding_blocks + 1): + diffusers_idx = idx - 1 + cond_block_id = 2 * idx + + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.pop( + f"input_hint_block.{cond_block_id}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.pop( + f"input_hint_block.{cond_block_id}.bias" + ) + + return new_checkpoint + + +def create_diffusers_controlnet_model_from_ldm( + pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None +): + # import here to avoid circular imports + from ..models import ControlNetModel + + image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size) + + diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size) + diffusers_config["upcast_attention"] = upcast_attention + + diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, diffusers_config) + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + controlnet = ControlNetModel(**diffusers_config) + + if is_accelerate_available(): + for param_name, param in diffusers_format_controlnet_checkpoint.items(): + set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) + else: + controlnet.load_state_dict(diffusers_format_controlnet_checkpoint) + + return {"controlnet": controlnet} + + +def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut") + new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) + + +def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping): + for ldm_key in keys: + diffusers_key = ( + ldm_key.replace(mapping["old"], mapping["new"]) + .replace("norm.weight", "group_norm.weight") + .replace("norm.bias", "group_norm.bias") + .replace("q.weight", "to_q.weight") + .replace("q.bias", "to_q.bias") + .replace("k.weight", "to_k.weight") + .replace("k.bias", "to_k.bias") + .replace("v.weight", "to_v.weight") + .replace("v.bias", "to_v.bias") + .replace("proj_out.weight", "to_out.0.weight") + .replace("proj_out.bias", "to_out.0.bias") + ) + new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) + + # proj_attn.weight has to be converted from conv 1D to linear + shape = new_checkpoint[diffusers_key].shape + + if len(shape) == 3: + new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0] + elif len(shape) == 4: + new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0] + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + 
vae_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["vae"] + for diffusers_key, ldm_key in vae_diffusers_ldm_map.items(): + if ldm_key not in vae_state_dict: + continue + new_checkpoint[diffusers_key] = vae_state_dict[ldm_key] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len(config["down_block_types"]) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}, + ) + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + update_vae_attentions_ldm_to_diffusers( + mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"} + ) + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len(config["up_block_types"]) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}, + ) + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + update_vae_resnet_ldm_to_diffusers( + resnets, + new_checkpoint, + vae_state_dict, + mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}, + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + update_vae_attentions_ldm_to_diffusers( + mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"} + ) + conv_attn_to_linear(new_checkpoint) + + return new_checkpoint + + +def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False): + try: + config = CLIPTextConfig.from_pretrained(config_name, 
local_files_only=local_files_only)
+    except Exception:
+        raise ValueError(
+            f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
+        )
+
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        text_model = CLIPTextModel(config)
+
+    keys = list(checkpoint.keys())
+    text_model_dict = {}
+
+    remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
+
+    for key in keys:
+        for prefix in remove_prefixes:
+            if key.startswith(prefix):
+                diffusers_key = key.replace(prefix, "")
+                text_model_dict[diffusers_key] = checkpoint[key]
+
+    if is_accelerate_available():
+        for param_name, param in text_model_dict.items():
+            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+    else:
+        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+            text_model_dict.pop("text_model.embeddings.position_ids", None)
+
+        text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
+def create_text_encoder_from_open_clip_checkpoint(
+    config_name,
+    checkpoint,
+    prefix="cond_stage_model.model.",
+    has_projection=False,
+    local_files_only=False,
+    **config_kwargs,
+):
+    try:
+        config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
+    except Exception:
+        raise ValueError(
+            f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
+        )
+
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
+
+    text_model_dict = {}
+    text_proj_key = prefix + "text_projection"
+    text_proj_dim = (
+        int(checkpoint[text_proj_key].shape[0]) if text_proj_key in checkpoint else LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
+    )
+    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
+
+    keys = list(checkpoint.keys())
+    keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
+
+    openclip_diffusers_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"]
+    for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items():
+        ldm_key = prefix + ldm_key
+        if ldm_key not in checkpoint:
+            continue
+        if ldm_key in keys_to_ignore:
+            continue
+        if ldm_key.endswith("text_projection"):
+            text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous()
+        else:
+            text_model_dict[diffusers_key] = checkpoint[ldm_key]
+
+    for key in keys:
+        if key in keys_to_ignore:
+            continue
+
+        if not key.startswith(prefix + "transformer."):
+            continue
+
+        diffusers_key = key.replace(prefix + "transformer.", "")
+        transformer_diffusers_to_ldm_map = DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"]
+        for new_key, old_key in transformer_diffusers_to_ldm_map.items():
+            diffusers_key = (
+                diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "")
+            )
+
+        if key.endswith(".in_proj_weight"):
+            weight_value = checkpoint[key]
+
+            text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :]
+            text_model_dict[diffusers_key + ".k_proj.weight"] = weight_value[text_proj_dim : text_proj_dim * 2, :]
+            text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :]
+
+        elif key.endswith(".in_proj_bias"):
+            weight_value = checkpoint[key]
+            text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim]
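+            # the fused OpenCLIP in_proj tensor stacks q, k and v along dim 0, so a
+            # bias of shape [3 * text_proj_dim] splits into three equal slices
+            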
text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2]
+            text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :]
+
+        else:
+            text_model_dict[diffusers_key] = checkpoint[key]
+
+    if is_accelerate_available():
+        for param_name, param in text_model_dict.items():
+            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+
+    else:
+        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
+            text_model_dict.pop("text_model.embeddings.position_ids", None)
+
+        text_model.load_state_dict(text_model_dict)
+
+    return text_model
+
+
+def create_diffusers_unet_model_from_ldm(
+    pipeline_class_name,
+    original_config,
+    checkpoint,
+    num_in_channels=None,
+    upcast_attention=False,
+    extract_ema=False,
+    image_size=None,
+):
+    from ..models import UNet2DConditionModel
+
+    if num_in_channels is None:
+        if pipeline_class_name in [
+            "StableDiffusionInpaintPipeline",
+            "StableDiffusionXLInpaintPipeline",
+            "StableDiffusionXLControlNetInpaintPipeline",
+        ]:
+            num_in_channels = 9
+
+        elif pipeline_class_name == "StableDiffusionUpscalePipeline":
+            num_in_channels = 7
+
+        else:
+            num_in_channels = 4
+
+    image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
+    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
+    unet_config["in_channels"] = num_in_channels
+    unet_config["upcast_attention"] = upcast_attention
+
+    diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+    with ctx():
+        unet = UNet2DConditionModel(**unet_config)
+
+    if is_accelerate_available():
+        for param_name, param in diffusers_format_unet_checkpoint.items():
+            set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+    else:
+        unet.load_state_dict(diffusers_format_unet_checkpoint)
+
+    return {"unet": unet}
+
+
+def create_diffusers_vae_model_from_ldm(
+    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=0.18215
+):
+    # import here to avoid circular imports
+    from ..models import AutoencoderKL
+
+    image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
+
+    vae_config = create_vae_diffusers_config(original_config, image_size=image_size, scaling_factor=scaling_factor)
+    diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+    ctx = init_empty_weights if is_accelerate_available() else nullcontext
+
+    with ctx():
+        vae = AutoencoderKL(**vae_config)
+
+    if is_accelerate_available():
+        for param_name, param in diffusers_format_vae_checkpoint.items():
+            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+    else:
+        vae.load_state_dict(diffusers_format_vae_checkpoint)
+
+    return {"vae": vae}
+
+
+def create_text_encoders_and_tokenizers_from_ldm(
+    original_config,
+    checkpoint,
+    model_type=None,
+    local_files_only=False,
+):
+    model_type = infer_model_type(original_config, model_type=model_type)
+
+    if model_type == "FrozenOpenCLIPEmbedder":
+        config_name = "stabilityai/stable-diffusion-2"
+        config_kwargs = {"subfolder": "text_encoder"}
+
+        try:
+            text_encoder = create_text_encoder_from_open_clip_checkpoint(
+                config_name, checkpoint, local_files_only=local_files_only, **config_kwargs
+            )
+            tokenizer = CLIPTokenizer.from_pretrained(
+                config_name, subfolder="tokenizer", local_files_only=local_files_only
+            
) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder in the following path: '{config_name}'." + ) + else: + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + elif model_type == "FrozenCLIPEmbedder": + try: + config_name = "openai/clip-vit-large-patch14" + text_encoder = create_text_encoder_from_ldm_clip_checkpoint( + config_name, checkpoint, local_files_only=local_files_only + ) + tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) + + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'." + ) + else: + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + elif model_type == "SDXL-Refiner": + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + prefix = "conditioner.embedders.0.model." + + try: + tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only) + text_encoder_2 = create_text_encoder_from_open_clip_checkpoint( + config_name, + checkpoint, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'." + ) + + else: + return { + "text_encoder": None, + "tokenizer": None, + "tokenizer_2": tokenizer_2, + "text_encoder_2": text_encoder_2, + } + + elif model_type == "SDXL": + try: + config_name = "openai/clip-vit-large-patch14" + tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) + text_encoder = create_text_encoder_from_ldm_clip_checkpoint( + config_name, checkpoint, local_files_only=local_files_only + ) + + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + try: + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + prefix = "conditioner.embedders.1.model." + tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only) + text_encoder_2 = create_text_encoder_from_open_clip_checkpoint( + config_name, + checkpoint, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'." 
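+                # (editor's note, assumption) '!' is used as the pad token because it
+                # maps to token id 0 in the CLIP BPE vocabulary, matching OpenCLIP's
+                # zero-padding convention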
+            )
+
+        return {
+            "tokenizer": tokenizer,
+            "text_encoder": text_encoder,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder_2": text_encoder_2,
+        }
+
+    raise ValueError(f"Unsupported `model_type`: {model_type}")
+
+
+def create_scheduler_from_ldm(
+    pipeline_class_name,
+    original_config,
+    checkpoint,
+    prediction_type=None,
+    scheduler_type="ddim",
+    model_type=None,
+):
+    scheduler_config = get_default_scheduler_config()
+    model_type = infer_model_type(original_config, model_type=model_type)
+
+    global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
+
+    num_train_timesteps = original_config["model"]["params"].get("timesteps", None) or 1000
+    scheduler_config["num_train_timesteps"] = num_train_timesteps
+
+    if (
+        "parameterization" in original_config["model"]["params"]
+        and original_config["model"]["params"]["parameterization"] == "v"
+    ):
+        if prediction_type is None:
+            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`
+            # as it relies on a brittle global step parameter here
+            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
+
+    else:
+        prediction_type = prediction_type or "epsilon"
+
+    scheduler_config["prediction_type"] = prediction_type
+
+    if model_type in ["SDXL", "SDXL-Refiner"]:
+        scheduler_type = "euler"
+
+    else:
+        beta_start = original_config["model"]["params"].get("linear_start", 0.02)
+        beta_end = original_config["model"]["params"].get("linear_end", 0.085)
+        scheduler_config["beta_start"] = beta_start
+        scheduler_config["beta_end"] = beta_end
+        scheduler_config["beta_schedule"] = "scaled_linear"
+        scheduler_config["clip_sample"] = False
+        scheduler_config["set_alpha_to_one"] = False
+
+    if scheduler_type == "pndm":
+        scheduler_config["skip_prk_steps"] = True
+        scheduler = PNDMScheduler.from_config(scheduler_config)
+
+    elif scheduler_type == "lms":
+        scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
+
+    elif scheduler_type == "heun":
+        scheduler = HeunDiscreteScheduler.from_config(scheduler_config)
+
+    elif scheduler_type == "euler":
+        scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
+
+    elif scheduler_type == "euler-ancestral":
+        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
+
+    elif scheduler_type == "dpm":
+        scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
+
+    elif scheduler_type == "ddim":
+        scheduler = DDIMScheduler.from_config(scheduler_config)
+
+    else:
+        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
+
+    if pipeline_class_name == "StableDiffusionUpscalePipeline":
+        scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler")
+        low_res_scheduler = DDPMScheduler.from_pretrained(
+            "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
+        )
+
+        return {
+            "scheduler": scheduler,
+            "low_res_scheduler": low_res_scheduler,
+        }
+
+    return {"scheduler": scheduler}
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
index ab4b16a193..68d5a31e43 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
@@ -17,7 +17,6 @@ import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalVAEMixin
 from ...utils import is_torch_version
 from ...utils.accelerate_utils import apply_forward_hook
 from
..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor @@ -162,7 +161,7 @@ class TemporalDecoder(nn.Module): return sample -class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin): +class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 65dc051390..60ddc999e9 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -19,7 +19,7 @@ from torch import nn from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalControlnetMixin +from ..loaders import FromOriginalControlNetMixin from ..utils import BaseOutput, logging from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -108,7 +108,7 @@ class ControlNetConditioningEmbedding(nn.Module): return embedding -class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin): """ A ControlNet model. diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 445c3ca71c..90a700d944 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -32,6 +32,7 @@ from .. import __version__ from ..utils import ( CONFIG_NAME, FLAX_WEIGHTS_NAME, + SAFETENSORS_FILE_EXTENSION, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, _add_variant, @@ -102,10 +103,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ Reads a checkpoint file, returning properly formatted errors if they arise. """ try: - if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant): - return torch.load(checkpoint_file, map_location="cpu") - else: + file_extension = os.path.basename(checkpoint_file).split(".")[-1] + if file_extension == SAFETENSORS_FILE_EXTENSION: return safetensors.torch.load_file(checkpoint_file, device="cpu") + else: + return torch.load(checkpoint_file, map_location="cpu") except Exception as e: try: with open(checkpoint_file) as f: diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index de5dea679e..4d0bc8b13a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -351,7 +351,7 @@ def get_class_obj_and_candidates( def _get_pipeline_class( class_obj, - config, + config=None, load_connected_pipeline=False, custom_pipeline=None, repo_id=None, @@ -389,7 +389,12 @@ def _get_pipeline_class( return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) - class_name = config["_class_name"] + class_name = class_name or config["_class_name"] + if not class_name: + raise ValueError( + "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`." 
+ ) + class_name = class_name[4:] if class_name.startswith("Flax") else class_name pipeline_cls = getattr(diffusers_module, class_name) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 468476e0c7..667f1fe5e2 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -28,6 +28,7 @@ from .constants import ( MIN_PEFT_VERSION, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, + SAFETENSORS_FILE_EXTENSION, SAFETENSORS_WEIGHTS_NAME, USE_PEFT_BACKEND, WEIGHTS_NAME, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 8850da073e..a397e8cf86 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -31,6 +31,7 @@ WEIGHTS_NAME = "diffusion_pytorch_model.bin" FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" ONNX_WEIGHTS_NAME = "model.onnx" SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" +SAFETENSORS_FILE_EXTENSION = "safetensors" ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index d762f015a7..6f311b957a 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -244,15 +244,15 @@ def _get_model_file( pretrained_model_name_or_path: Union[str, Path], *, weights_name: str, - subfolder: Optional[str], - cache_dir: Optional[str], - force_download: bool, - proxies: Optional[Dict], - resume_download: bool, - local_files_only: bool, - token: Optional[str], - user_agent: Union[Dict, str, None], - revision: Optional[str], + subfolder: Optional[str] = None, + cache_dir: Optional[str] = None, + force_download: bool = False, + proxies: Optional[Dict] = None, + resume_download: bool = False, + local_files_only: bool = False, + token: Optional[str] = None, + user_agent: Optional[Union[Dict, str]] = None, + revision: Optional[str] = None, commit_hash: Optional[str] = None, ): pretrained_model_name_or_path = str(pretrained_model_name_or_path) diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index c034a9b68b..05f3ade508 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -37,6 +37,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, load_image, load_numpy, + numpy_cosine_similarity_distance, require_python39_or_higher, require_torch_2, require_torch_gpu, @@ -1022,39 +1023,49 @@ class ControlNetPipelineSlowTests(unittest.TestCase): def test_load_local(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny") - pipe_1 = StableDiffusionControlNetPipeline.from_pretrained( + pipe = StableDiffusionControlNetPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) + pipe.unet.set_default_attn_processor() + pipe.enable_model_cpu_offload() controlnet = ControlNetModel.from_single_file( "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" ) - pipe_2 = StableDiffusionControlNetPipeline.from_single_file( + pipe_sf = StableDiffusionControlNetPipeline.from_single_file( "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors", safety_checker=None, controlnet=controlnet, + scheduler_type="pndm", ) - pipes = [pipe_1, pipe_2] - images = [] + 
pipe_sf.unet.set_default_attn_processor() + pipe_sf.enable_model_cpu_offload() - for pipe in pipes: - pipe.enable_model_cpu_offload() - pipe.set_progress_bar_config(disable=None) + control_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + prompt = "bird" - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = "bird" - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" - ) + generator = torch.Generator(device="cpu").manual_seed(0) + output = pipe( + prompt, + image=control_image, + generator=generator, + output_type="np", + num_inference_steps=3, + ).images[0] - output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3) - images.append(output.images[0]) + generator = torch.Generator(device="cpu").manual_seed(0) + output_sf = pipe_sf( + prompt, + image=control_image, + generator=generator, + output_type="np", + num_inference_steps=3, + ).images[0] - del pipe - gc.collect() - torch.cuda.empty_cache() - - assert np.abs(images[0] - images[1]).max() < 1e-3 + max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten()) + assert max_diff < 1e-3 @slow diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index b4b67e6476..939eb34ef0 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -39,6 +39,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, load_numpy, + numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device, @@ -421,46 +422,53 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase): def test_load_local(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny") - pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet ) + pipe.unet.set_default_attn_processor() + pipe.enable_model_cpu_offload() controlnet = ControlNetModel.from_single_file( "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" ) - pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file( + pipe_sf = StableDiffusionControlNetImg2ImgPipeline.from_single_file( "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors", safety_checker=None, controlnet=controlnet, + scheduler_type="pndm", ) + pipe_sf.unet.set_default_attn_processor() + pipe_sf.enable_model_cpu_offload() + control_image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" ).resize((512, 512)) image = load_image( "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png" ).resize((512, 512)) + prompt = "bird" - pipes = [pipe_1, pipe_2] - images = [] - for pipe in pipes: - pipe.enable_model_cpu_offload() - pipe.set_progress_bar_config(disable=None) + generator = torch.Generator(device="cpu").manual_seed(0) + output = pipe( + prompt, + image=image, + control_image=control_image, + strength=0.9, + generator=generator, + output_type="np", + num_inference_steps=3, + ).images[0] - generator = torch.Generator(device="cpu").manual_seed(0) - prompt = 
"bird" - output = pipe( - prompt, - image=image, - control_image=control_image, - strength=0.9, - generator=generator, - output_type="np", - num_inference_steps=3, - ) - images.append(output.images[0]) + generator = torch.Generator(device="cpu").manual_seed(0) + output_sf = pipe_sf( + prompt, + image=image, + control_image=control_image, + strength=0.9, + generator=generator, + output_type="np", + num_inference_steps=3, + ).images[0] - del pipe - gc.collect() - torch.cuda.empty_cache() - - assert np.abs(images[0] - images[1]).max() < 1e-3 + max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten()) + assert max_diff < 1e-3 diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index 7c3371c197..7db336df94 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -569,6 +569,7 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase): "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors", safety_checker=None, controlnet=controlnet, + scheduler_type="pndm", ) control_image = load_image( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" @@ -605,4 +606,5 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase): gc.collect() torch.cuda.empty_cache() - assert np.abs(images[0] - images[1]).max() < 1e-3 + max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images[1].flatten()) + assert max_diff < 1e-3 diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 88cf254ff6..7ccf724361 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -31,7 +31,14 @@ from diffusers import ( from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device +from diffusers.utils.testing_utils import ( + enable_full_determinism, + load_image, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) from diffusers.utils.torch_utils import randn_tensor from ..pipeline_params import ( @@ -819,6 +826,41 @@ class ControlNetSDXLPipelineSlowTests(unittest.TestCase): expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853]) assert np.allclose(original_image, expected_image, atol=1e-04) + def test_download_ckpt_diff_format_is_same(self): + controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16) + single_file_url = ( + "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors" + ) + pipe_single_file = StableDiffusionXLControlNetPipeline.from_single_file( + single_file_url, controlnet=controlnet, torch_dtype=torch.float16 + ) + pipe_single_file.unet.set_default_attn_processor() + pipe_single_file.enable_model_cpu_offload() + pipe_single_file.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "Stormtrooper's lecture" + image = load_image( + 
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png" + ) + single_file_images = pipe_single_file( + prompt, image=image, generator=generator, output_type="np", num_inference_steps=2 + ).images + + generator = torch.Generator(device="cpu").manual_seed(0) + pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 + ) + pipe.unet.set_default_attn_processor() + pipe.enable_model_cpu_offload() + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=2).images + + assert images[0].shape == (512, 512, 3) + assert single_file_images[0].shape == (512, 512, 3) + + max_diff = numpy_cosine_similarity_distance(images[0].flatten(), single_file_images[0].flatten()) + assert max_diff < 5e-2 + class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNetPipelineFastTests): def test_controlnet_sdxl_guess(self): diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 4ea22574cc..bfdcec09f2 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -1262,13 +1262,13 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase): def test_download_ckpt_diff_format_is_same(self): ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt" - pipe = StableDiffusionPipeline.from_single_file(ckpt_path) - pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) - pipe.unet.set_attn_processor(AttnProcessor()) - pipe.to("cuda") + sf_pipe = StableDiffusionPipeline.from_single_file(ckpt_path) + sf_pipe.scheduler = DDIMScheduler.from_config(sf_pipe.scheduler.config) + sf_pipe.unet.set_attn_processor(AttnProcessor()) + sf_pipe.to("cuda") generator = torch.Generator(device="cpu").manual_seed(0) - image_ckpt = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0] + image_single_file = sf_pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0] pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) @@ -1278,7 +1278,7 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase): generator = torch.Generator(device="cpu").manual_seed(0) image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0] - max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten()) + max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten()) assert max_diff < 1e-3 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index fe664b21e2..7ec6964b06 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -43,6 +43,7 @@ from diffusers.utils.testing_utils import ( load_image, load_numpy, nightly, + numpy_cosine_similarity_distance, require_python39_or_higher, require_torch_2, require_torch_gpu, @@ -771,7 +772,9 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase): inputs["num_inference_steps"] = 5 image = pipe(**inputs).images[0] - assert np.max(np.abs(image - image_ckpt)) < 5e-4 + 
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
+
+        assert max_diff < 1e-4
 
 
 @slow
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 80bff3663a..d5ad5ee0d7 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import copy
+import gc
 import tempfile
 import unittest
 
@@ -1024,6 +1025,11 @@ class StableDiffusionXLPipelineFastTests(
 
 @slow
 class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def test_stable_diffusion_lcm(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel.from_pretrained(
@@ -1049,3 +1055,30 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
 
         max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
         assert max_diff < 1e-2
+
+    def test_download_ckpt_diff_format_is_same(self):
+        ckpt_path = (
+            "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
+        )
+
+        sf_pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
+        sf_pipe.scheduler = DDIMScheduler.from_config(sf_pipe.scheduler.config)
+        sf_pipe.unet.set_default_attn_processor()
+        sf_pipe.enable_model_cpu_offload()
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        image_single_file = sf_pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
+
+        pipe = StableDiffusionXLPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        )
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_default_attn_processor()
+        pipe.enable_model_cpu_offload()
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
+
+        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
+
+        assert max_diff < 6e-3
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
index 2296e092e7..1513525df3 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
@@ -699,3 +699,40 @@ class AdapterSDXLPipelineSlowTests(unittest.TestCase):
         image_slice = images[0, -3:, -3:, -1].flatten()
         expected_slice = np.array([0.4284, 0.4337, 0.4319, 0.4255, 0.4329, 0.4280, 0.4338, 0.4420, 0.4226])
         assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4
+
+    def test_download_ckpt_diff_format_is_same(self):
+        ckpt_path = (
+            "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
+        )
+        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
+        prompt = "toy"
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
+        )
+        pipe_single_file = StableDiffusionXLAdapterPipeline.from_single_file(
+            ckpt_path,
+            adapter=adapter,
+            torch_dtype=torch.float16,
+        )
+        pipe_single_file.enable_model_cpu_offload()
+        pipe_single_file.set_progress_bar_config(disable=None)
+
+        generator = 
torch.Generator(device="cpu").manual_seed(0) + images_single_file = pipe_single_file( + prompt, image=image, generator=generator, output_type="np", num_inference_steps=3 + ).images + + generator = torch.Generator(device="cpu").manual_seed(0) + pipe = StableDiffusionXLAdapterPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + adapter=adapter, + torch_dtype=torch.float16, + ) + pipe.enable_model_cpu_offload() + images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images + + assert images_single_file[0].shape == (768, 512, 3) + assert images[0].shape == (768, 512, 3) + + max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images_single_file[0].flatten()) + assert max_diff < 5e-3 diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 0a7d4d0de4..e505630cf6 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import random import unittest @@ -31,15 +32,19 @@ from transformers import ( from diffusers import ( AutoencoderKL, AutoencoderTiny, + DDIMScheduler, EulerDiscreteScheduler, LCMScheduler, StableDiffusionXLImg2ImgPipeline, UNet2DConditionModel, ) +from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, + numpy_cosine_similarity_distance, require_torch_gpu, + slow, torch_device, ) @@ -763,3 +768,44 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests( def test_save_load_optional_components(self): self._test_save_load_optional_components() + + +@slow +class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_download_ckpt_diff_format_is_same(self): + ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors" + init_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" + "/stable_diffusion_img2img/sketch-mountains-input.png" + ) + + pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16 + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.unet.set_default_attn_processor() + pipe.enable_model_cpu_offload() + + generator = torch.Generator(device="cpu").manual_seed(0) + image = pipe( + prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np" + ).images[0] + + pipe_single_file = StableDiffusionXLImg2ImgPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16) + pipe_single_file.scheduler = DDIMScheduler.from_config(pipe_single_file.scheduler.config) + pipe_single_file.unet.set_default_attn_processor() + pipe_single_file.enable_model_cpu_offload() + + generator = torch.Generator(device="cpu").manual_seed(0) + image_single_file = pipe_single_file( + prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np" + ).images[0] + + max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten()) + + assert max_diff < 5e-2
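Note on the comparison metric: every new test_download_ckpt_diff_format_is_same test above asserts on numpy_cosine_similarity_distance, imported from diffusers.utils.testing_utils. As a minimal sketch of what such a distance amounts to, assuming the helper returns one minus the cosine similarity of the two flattened image arrays (the function name below is ours for illustration, not the library's):

    import numpy as np

    def cosine_similarity_distance_sketch(a: np.ndarray, b: np.ndarray) -> float:
        # Treat the flattened images as vectors: identical outputs give a
        # distance of 0.0, uncorrelated outputs a distance near 1.0.
        similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        return float(1.0 - similarity)

A relative metric of this kind is less brittle than the elementwise np.abs(a - b).max() checks it replaces once the pipelines run under fp16 and model CPU offloading, which is why thresholds such as 6e-3 and 5e-2 can stay loose while still catching a genuinely different image.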