From 6b1abba18dce2eb169b6363ea5f626e7cd87cf21 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 19 Jul 2023 14:50:27 +0200 Subject: [PATCH] Add controlnet and vae from single file (#4084) * Add controlnet from single file * Updates * make style * finish * Apply suggestions from code review Co-authored-by: Sayak Paul --------- Co-authored-by: Sayak Paul --- docs/source/en/api/loaders.mdx | 8 + docs/source/en/api/models/autoencoderkl.mdx | 14 +- docs/source/en/api/models/controlnet.mdx | 17 +- src/diffusers/loaders.py | 363 +++++++++++++++++- src/diffusers/models/autoencoder_kl.py | 3 +- src/diffusers/models/controlnet.py | 3 +- .../controlnet/pipeline_controlnet.py | 6 +- .../controlnet/pipeline_controlnet_img2img.py | 6 +- .../controlnet/pipeline_controlnet_inpaint.py | 6 +- .../stable_diffusion/convert_from_ckpt.py | 25 +- tests/models/test_models_vae.py | 21 +- tests/pipelines/controlnet/test_controlnet.py | 36 ++ .../controlnet/test_controlnet_img2img.py | 46 +++ .../controlnet/test_controlnet_inpaint.py | 51 +++ 14 files changed, 576 insertions(+), 29 deletions(-) diff --git a/docs/source/en/api/loaders.mdx b/docs/source/en/api/loaders.mdx index 57891d23de..98aaea0060 100644 --- a/docs/source/en/api/loaders.mdx +++ b/docs/source/en/api/loaders.mdx @@ -35,3 +35,11 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio ## FromSingleFileMixin [[autodoc]] loaders.FromSingleFileMixin + +## FromOriginalControlnetMixin + +[[autodoc]] loaders.FromOriginalControlnetMixin + +## FromOriginalVAEMixin + +[[autodoc]] loaders.FromOriginalVAEMixin diff --git a/docs/source/en/api/models/autoencoderkl.mdx b/docs/source/en/api/models/autoencoderkl.mdx index 542fc27cd5..bc709c422d 100644 --- a/docs/source/en/api/models/autoencoderkl.mdx +++ b/docs/source/en/api/models/autoencoderkl.mdx @@ -6,6 +6,18 @@ The abstract from the paper is: *How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.* +## Loading from the original format + +By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded +from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows: + +```py +from diffusers import AutoencoderKL + +url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file +model = AutoencoderKL.from_single_file(url) +``` + ## AutoencoderKL [[autodoc]] AutoencoderKL @@ -28,4 +40,4 @@ The abstract from the paper is: ## FlaxDecoderOutput -[[autodoc]] models.vae_flax.FlaxDecoderOutput \ No newline at end of file +[[autodoc]] models.vae_flax.FlaxDecoderOutput diff --git a/docs/source/en/api/models/controlnet.mdx b/docs/source/en/api/models/controlnet.mdx index ae2d06edbb..e02adde8a1 100644 --- a/docs/source/en/api/models/controlnet.mdx +++ b/docs/source/en/api/models/controlnet.mdx @@ -6,6 +6,21 @@ The abstract from the paper is: *We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.* +## Loading from the original format + +By default the [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded +from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows: + +```py +from diffusers import StableDiffusionControlnetPipeline, ControlNetModel + +url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path +controlnet = ControlNetModel.from_single_file(url) + +url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path +pipe = StableDiffusionControlnetPipeline.from_single_file(url, controlnet=controlnet) +``` + ## ControlNetModel [[autodoc]] ControlNetModel @@ -20,4 +35,4 @@ The abstract from the paper is: ## FlaxControlNetOutput -[[autodoc]] models.controlnet_flax.FlaxControlNetOutput \ No newline at end of file +[[autodoc]] models.controlnet_flax.FlaxControlNetOutput diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index db285f30ef..8ce5989b5f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -14,9 +14,12 @@ import os import warnings from collections import defaultdict +from contextlib import nullcontext +from io import BytesIO from pathlib import Path from typing import Callable, Dict, List, Optional, Union +import requests import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download @@ -42,10 +45,13 @@ from .utils import ( HF_HUB_OFFLINE, _get_model_file, deprecate, + is_accelerate_available, + is_omegaconf_available, is_safetensors_available, is_transformers_available, logging, ) +from .utils.import_utils import BACKENDS_MAPPING if is_safetensors_available(): @@ -54,6 +60,9 @@ if is_safetensors_available(): if is_transformers_available(): from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device logger = logging.get_logger(__name__) @@ -1319,8 +1328,8 @@ class FromSingleFileMixin: @classmethod def from_single_file(cls, pretrained_model_link_or_path, **kwargs): r""" - Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline - is set in evaluation mode (`model.eval()`) by default. + Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors` + format. The pipeline is set in evaluation mode (`model.eval()`) by default. Parameters: pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): @@ -1430,6 +1439,7 @@ class FromSingleFileMixin: load_safety_checker = kwargs.pop("load_safety_checker", True) prediction_type = kwargs.pop("prediction_type", None) text_encoder = kwargs.pop("text_encoder", None) + controlnet = kwargs.pop("controlnet", None) tokenizer = kwargs.pop("tokenizer", None) torch_dtype = kwargs.pop("torch_dtype", None) @@ -1446,11 +1456,18 @@ class FromSingleFileMixin: # TODO: For now we only support stable diffusion stable_unclip = None model_type = None - controlnet = False - if pipeline_name == "StableDiffusionControlNetPipeline": + if pipeline_name in [ + "StableDiffusionControlNetPipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + ]: + from .models.controlnet import ControlNetModel + from .pipelines.controlnet.multicontrolnet import MultiControlNetModel + # Model type will be inferred from the checkpoint. - controlnet = True + if not isinstance(controlnet, (ControlNetModel, MultiControlNetModel)): + raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.") elif "StableDiffusion" in pipeline_name: # Model type will be inferred from the checkpoint. pass @@ -1519,3 +1536,339 @@ class FromSingleFileMixin: pipe.to(torch_dtype=torch_dtype) return pipe + + +class FromOriginalVAEMixin: + @classmethod + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`AutoencoderKL`] from pretrained controlnet weights saved in the original `.ckpt` or + `.safetensors` format. The pipeline is format. The pipeline is set in evaluation mode (`model.eval()`) by + default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z + = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution + Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + + + Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you want to load + a VAE that does accompany a stable diffusion model of v2 or higher or SDXL. + + + + Examples: + + ```py + from diffusers import AutoencoderKL + + url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file + model = AutoencoderKL.from_single_file(url) + ``` + """ + if not is_omegaconf_available(): + raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) + + from omegaconf import OmegaConf + + from .models import AutoencoderKL + + # import here to avoid circular dependency + from .pipelines.stable_diffusion.convert_from_ckpt import ( + convert_ldm_vae_checkpoint, + create_vae_diffusers_config, + ) + + config_file = kwargs.pop("config_file", None) + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + image_size = kwargs.pop("image_size", None) + scaling_factor = kwargs.pop("scaling_factor", None) + kwargs.pop("upcast_attention", None) + + torch_dtype = kwargs.pop("torch_dtype", None) + + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + + file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] + from_safetensors = file_extension == "safetensors" + + if from_safetensors and use_safetensors is False: + raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") + + # remove huggingface url + for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: + if pretrained_model_link_or_path.startswith(prefix): + pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] + + # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained + ckpt_path = Path(pretrained_model_link_or_path) + if not ckpt_path.is_file(): + # get repo_id and (potentially nested) file path of ckpt in repo + repo_id = "/".join(ckpt_path.parts[:2]) + file_path = "/".join(ckpt_path.parts[2:]) + + if file_path.startswith("blob/"): + file_path = file_path[len("blob/") :] + + if file_path.startswith("main/"): + file_path = file_path[len("main/") :] + + pretrained_model_link_or_path = hf_hub_download( + repo_id, + filename=file_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + ) + + if from_safetensors: + from safetensors import safe_open + + checkpoint = {} + with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu") + + if "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + if config_file is None: + config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + config_file = BytesIO(requests.get(config_url).content) + + original_config = OmegaConf.load(config_file) + + # default to sd-v1-5 + image_size = image_size or 512 + + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + if scaling_factor is None: + if ( + "model" in original_config + and "params" in original_config.model + and "scale_factor" in original_config.model.params + ): + vae_scaling_factor = original_config.model.params.scale_factor + else: + vae_scaling_factor = 0.18215 # default SD scaling factor + + vae_config["scaling_factor"] = vae_scaling_factor + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + vae = AutoencoderKL(**vae_config) + + if is_accelerate_available(): + for param_name, param in converted_vae_checkpoint.items(): + set_module_tensor_to_device(vae, param_name, "cpu", value=param) + else: + vae.load_state_dict(converted_vae_checkpoint) + + if torch_dtype is not None: + vae.to(torch_dtype=torch_dtype) + + return vae + + +class FromOriginalControlnetMixin: + @classmethod + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`ControlNetModel`] from pretrained controlnet weights saved in the original `.ckpt` or + `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + Examples: + + ```py + from diffusers import StableDiffusionControlnetPipeline, ControlNetModel + + url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path + model = ControlNetModel.from_single_file(url) + + url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path + pipe = StableDiffusionControlnetPipeline.from_single_file(url, controlnet=controlnet) + ``` + """ + # import here to avoid circular dependency + from .pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt + + config_file = kwargs.pop("config_file", None) + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + num_in_channels = kwargs.pop("num_in_channels", None) + use_linear_projection = kwargs.pop("use_linear_projection", None) + revision = kwargs.pop("revision", None) + extract_ema = kwargs.pop("extract_ema", False) + image_size = kwargs.pop("image_size", None) + upcast_attention = kwargs.pop("upcast_attention", None) + + torch_dtype = kwargs.pop("torch_dtype", None) + + use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) + + file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] + from_safetensors = file_extension == "safetensors" + + if from_safetensors and use_safetensors is False: + raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") + + # remove huggingface url + for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]: + if pretrained_model_link_or_path.startswith(prefix): + pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] + + # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained + ckpt_path = Path(pretrained_model_link_or_path) + if not ckpt_path.is_file(): + # get repo_id and (potentially nested) file path of ckpt in repo + repo_id = "/".join(ckpt_path.parts[:2]) + file_path = "/".join(ckpt_path.parts[2:]) + + if file_path.startswith("blob/"): + file_path = file_path[len("blob/") :] + + if file_path.startswith("main/"): + file_path = file_path[len("main/") :] + + pretrained_model_link_or_path = hf_hub_download( + repo_id, + filename=file_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + ) + + if config_file is None: + config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml" + config_file = BytesIO(requests.get(config_url).content) + + image_size = image_size or 512 + + controlnet = download_controlnet_from_original_ckpt( + pretrained_model_link_or_path, + original_config_file=config_file, + image_size=image_size, + extract_ema=extract_ema, + num_in_channels=num_in_channels, + upcast_attention=upcast_attention, + from_safetensors=from_safetensors, + use_linear_projection=use_linear_projection, + ) + + if torch_dtype is not None: + controlnet.to(torch_dtype=torch_dtype) + + return controlnet diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index c4fdf751ad..2390d2bc58 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalVAEMixin from ..utils import BaseOutput, apply_forward_hook from .attention_processor import AttentionProcessor, AttnProcessor from .modeling_utils import ModelMixin @@ -38,7 +39,7 @@ class AutoencoderKLOutput(BaseOutput): latent_dist: "DiagonalGaussianDistribution" -class AutoencoderKL(ModelMixin, ConfigMixin): +class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index acebee4b40..354cd5851d 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -19,6 +19,7 @@ from torch import nn from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalControlnetMixin from ..utils import BaseOutput, logging from .attention_processor import AttentionProcessor, AttnProcessor from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps @@ -100,7 +101,7 @@ class ControlNetConditioningEmbedding(nn.Module): return embedding -class ControlNetModel(ModelMixin, ConfigMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): """ A ControlNet model. diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index dc6e22a7d8..2fcc0c67ed 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -90,7 +90,9 @@ EXAMPLE_DOC_STRING = """ """ -class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class StableDiffusionControlNetPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 30eba4fa83..c29a00a354 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -24,7 +24,7 @@ import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -116,7 +116,9 @@ def prepare_image(image): return image -class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class StableDiffusionControlNetImg2ImgPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 2d850f2043..b7481a0d43 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -25,7 +25,7 @@ import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -222,7 +222,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False return mask, masked_image -class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class StableDiffusionControlNetInpaintPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 747879192c..0eeb80f12d 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -621,8 +621,8 @@ def convert_ldm_unet_checkpoint( def convert_ldm_vae_checkpoint(checkpoint, config): # extract state dict for VAE vae_state_dict = {} - vae_key = "first_stage_model." keys = list(checkpoint.keys()) + vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else "" for key in keys: if key.startswith(vae_key): vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) @@ -1064,7 +1064,7 @@ def convert_controlnet_checkpoint( if cross_attention_dim is not None: ctrlnet_config["cross_attention_dim"] = cross_attention_dim - controlnet_model = ControlNetModel(**ctrlnet_config) + controlnet = ControlNetModel(**ctrlnet_config) # Some controlnet ckpt files are distributed independently from the rest of the # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ @@ -1082,9 +1082,9 @@ def convert_controlnet_checkpoint( skip_extract_state_dict=skip_extract_state_dict, ) - controlnet_model.load_state_dict(converted_ctrl_checkpoint) + controlnet.load_state_dict(converted_ctrl_checkpoint) - return controlnet_model + return controlnet def download_from_original_stable_diffusion_ckpt( @@ -1181,7 +1181,7 @@ def download_from_original_stable_diffusion_ckpt( ) if pipeline_class is None: - pipeline_class = StableDiffusionPipeline + pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline if prediction_type == "v-prediction": prediction_type = "v_prediction" @@ -1288,8 +1288,7 @@ def download_from_original_stable_diffusion_ckpt( if controlnet is None: controlnet = "control_stage_config" in original_config.model.params - if controlnet: - controlnet_model = convert_controlnet_checkpoint( + controlnet = convert_controlnet_checkpoint( checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema ) @@ -1400,13 +1399,13 @@ def download_from_original_stable_diffusion_ckpt( if stable_unclip is None: if controlnet: - pipe = StableDiffusionControlNetPipeline( + pipe = pipeline_class( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler, - controlnet=controlnet_model, + controlnet=controlnet, safety_checker=None, feature_extractor=None, requires_safety_checker=False, @@ -1503,12 +1502,12 @@ def download_from_original_stable_diffusion_ckpt( feature_extractor = None if controlnet: - pipe = StableDiffusionControlNetPipeline( + pipe = pipeline_class( vae=vae, text_encoder=text_model, tokenizer=tokenizer, unet=unet, - controlnet=controlnet_model, + controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, @@ -1623,7 +1622,7 @@ def download_controlnet_from_original_ckpt( if "control_stage_config" not in original_config.model.params: raise ValueError("`control_stage_config` not present in original config") - controlnet_model = convert_controlnet_checkpoint( + controlnet = convert_controlnet_checkpoint( checkpoint, original_config, checkpoint_path, @@ -1634,4 +1633,4 @@ def download_controlnet_from_original_ckpt( cross_attention_dim=cross_attention_dim, ) - return controlnet_model + return controlnet diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 08b030bbf9..245a5bb89b 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -199,7 +199,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase): torch_dtype=torch_dtype, revision=revision, ) - model.to(torch_device).eval() + model.to(torch_device) return model @@ -383,3 +383,22 @@ class AutoencoderKLIntegrationTests(unittest.TestCase): tolerance = 3e-3 if torch_device != "mps" else 1e-2 assert torch_all_close(output_slice, expected_output_slice, atol=tolerance) + + def test_stable_diffusion_model_local(self): + model_id = "stabilityai/sd-vae-ft-mse" + model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device) + + url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" + model_2 = AutoencoderKL.from_single_file(url).to(torch_device) + image = self.get_sd_image(33) + + with torch.no_grad(): + sample_1 = model_1(image).sample + sample_2 = model_2(image).sample + + assert sample_1.shape == sample_2.shape + + output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu() + output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu() + + assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3) diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index a548983c38..902b4bce97 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -752,6 +752,42 @@ class ControlNetPipelineSlowTests(unittest.TestCase): expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_load_local(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny") + pipe_1 = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + + controlnet = ControlNetModel.from_single_file( + "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" + ) + pipe_2 = StableDiffusionControlNetPipeline.from_single_file( + "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors", + safety_checker=None, + controlnet=controlnet, + ) + pipes = [pipe_1, pipe_2] + images = [] + + for pipe in pipes: + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3) + images.append(output.images[0]) + + del pipe + gc.collect() + torch.cuda.empty_cache() + + assert np.abs(images[0] - images[1]).sum() < 1e-3 + @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index c46593f03e..2b9ec7e463 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -401,3 +401,49 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase): ) assert np.abs(expected_image - image).max() < 9e-2 + + def test_load_local(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny") + pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + + controlnet = ControlNetModel.from_single_file( + "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" + ) + pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file( + "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors", + safety_checker=None, + controlnet=controlnet, + ) + control_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + image = load_image( + "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png" + ).resize((512, 512)) + + pipes = [pipe_1, pipe_2] + images = [] + for pipe in pipes: + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + output = pipe( + prompt, + image=image, + control_image=control_image, + strength=0.9, + generator=generator, + output_type="np", + num_inference_steps=3, + ) + images.append(output.images[0]) + + del pipe + gc.collect() + torch.cuda.empty_cache() + + assert np.abs(images[0] - images[1]).sum() < 1e-3 diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index cf423f4c49..cb9b53e612 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -543,3 +543,54 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase): ) assert np.abs(expected_image - image).max() < 9e-2 + + def test_load_local(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny") + pipe_1 = StableDiffusionControlNetInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + + controlnet = ControlNetModel.from_single_file( + "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" + ) + pipe_2 = StableDiffusionControlNetInpaintPipeline.from_single_file( + "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors", + safety_checker=None, + controlnet=controlnet, + ) + control_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + image = load_image( + "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png" + ).resize((512, 512)) + mask_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" + "/stable_diffusion_inpaint/input_bench_mask.png" + ).resize((512, 512)) + + pipes = [pipe_1, pipe_2] + images = [] + for pipe in pipes: + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + output = pipe( + prompt, + image=image, + control_image=control_image, + mask_image=mask_image, + strength=0.9, + generator=generator, + output_type="np", + num_inference_steps=3, + ) + images.append(output.images[0]) + + del pipe + gc.collect() + torch.cuda.empty_cache() + + assert np.abs(images[0] - images[1]).sum() < 1e-3