mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add controlnet and vae from single file (#4084)
* Add controlnet from single file * Updates * make style * finish * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
committed by
GitHub
parent
470f51cd26
commit
6b1abba18d
@@ -35,3 +35,11 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio
|
||||
## FromSingleFileMixin
|
||||
|
||||
[[autodoc]] loaders.FromSingleFileMixin
|
||||
|
||||
## FromOriginalControlnetMixin
|
||||
|
||||
[[autodoc]] loaders.FromOriginalControlnetMixin
|
||||
|
||||
## FromOriginalVAEMixin
|
||||
|
||||
[[autodoc]] loaders.FromOriginalVAEMixin
|
||||
|
||||
@@ -6,6 +6,18 @@ The abstract from the paper is:
|
||||
|
||||
*How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.*
|
||||
|
||||
## Loading from the original format
|
||||
|
||||
By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
|
||||
from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
|
||||
## AutoencoderKL
|
||||
|
||||
[[autodoc]] AutoencoderKL
|
||||
@@ -28,4 +40,4 @@ The abstract from the paper is:
|
||||
|
||||
## FlaxDecoderOutput
|
||||
|
||||
[[autodoc]] models.vae_flax.FlaxDecoderOutput
|
||||
[[autodoc]] models.vae_flax.FlaxDecoderOutput
|
||||
|
||||
@@ -6,6 +6,21 @@ The abstract from the paper is:
|
||||
|
||||
*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.*
|
||||
|
||||
## Loading from the original format
|
||||
|
||||
By default the [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
|
||||
from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlnetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
controlnet = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlnetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
|
||||
## ControlNetModel
|
||||
|
||||
[[autodoc]] ControlNetModel
|
||||
@@ -20,4 +35,4 @@ The abstract from the paper is:
|
||||
|
||||
## FlaxControlNetOutput
|
||||
|
||||
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
|
||||
[[autodoc]] models.controlnet_flax.FlaxControlNetOutput
|
||||
|
||||
@@ -14,9 +14,12 @@
|
||||
import os
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
@@ -42,10 +45,13 @@ from .utils import (
|
||||
HF_HUB_OFFLINE,
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_omegaconf_available,
|
||||
is_safetensors_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
from .utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
@@ -54,6 +60,9 @@ if is_safetensors_available():
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -1319,8 +1328,8 @@ class FromSingleFileMixin:
|
||||
@classmethod
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
|
||||
is set in evaluation mode (`model.eval()`) by default.
|
||||
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
|
||||
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
@@ -1430,6 +1439,7 @@ class FromSingleFileMixin:
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
||||
prediction_type = kwargs.pop("prediction_type", None)
|
||||
text_encoder = kwargs.pop("text_encoder", None)
|
||||
controlnet = kwargs.pop("controlnet", None)
|
||||
tokenizer = kwargs.pop("tokenizer", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
@@ -1446,11 +1456,18 @@ class FromSingleFileMixin:
|
||||
# TODO: For now we only support stable diffusion
|
||||
stable_unclip = None
|
||||
model_type = None
|
||||
controlnet = False
|
||||
|
||||
if pipeline_name == "StableDiffusionControlNetPipeline":
|
||||
if pipeline_name in [
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
]:
|
||||
from .models.controlnet import ControlNetModel
|
||||
from .pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
|
||||
# Model type will be inferred from the checkpoint.
|
||||
controlnet = True
|
||||
if not isinstance(controlnet, (ControlNetModel, MultiControlNetModel)):
|
||||
raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
|
||||
elif "StableDiffusion" in pipeline_name:
|
||||
# Model type will be inferred from the checkpoint.
|
||||
pass
|
||||
@@ -1519,3 +1536,339 @@ class FromSingleFileMixin:
|
||||
pipe.to(torch_dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
@classmethod
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained controlnet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is format. The pipeline is set in evaluation mode (`model.eval()`) by
|
||||
default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
|
||||
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
|
||||
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you want to load
|
||||
a VAE that does accompany a stable diffusion model of v2 or higher or SDXL.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
if not is_omegaconf_available():
|
||||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from .models import AutoencoderKL
|
||||
|
||||
# import here to avoid circular dependency
|
||||
from .pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_vae_diffusers_config,
|
||||
)
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scaling_factor = kwargs.pop("scaling_factor", None)
|
||||
kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if from_safetensors:
|
||||
from safetensors import safe_open
|
||||
|
||||
checkpoint = {}
|
||||
with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
checkpoint[key] = f.get_tensor(key)
|
||||
else:
|
||||
checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = OmegaConf.load(config_file)
|
||||
|
||||
# default to sd-v1-5
|
||||
image_size = image_size or 512
|
||||
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
if scaling_factor is None:
|
||||
if (
|
||||
"model" in original_config
|
||||
and "params" in original_config.model
|
||||
and "scale_factor" in original_config.model.params
|
||||
):
|
||||
vae_scaling_factor = original_config.model.params.scale_factor
|
||||
else:
|
||||
vae_scaling_factor = 0.18215 # default SD scaling factor
|
||||
|
||||
vae_config["scaling_factor"] = vae_scaling_factor
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
for param_name, param in converted_vae_checkpoint.items():
|
||||
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
|
||||
else:
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
vae.to(torch_dtype=torch_dtype)
|
||||
|
||||
return vae
|
||||
|
||||
|
||||
class FromOriginalControlnetMixin:
|
||||
@classmethod
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained controlnet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlnetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlnetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from .pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
use_linear_projection = kwargs.pop("use_linear_projection", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
image_size = image_size or 512
|
||||
|
||||
controlnet = download_controlnet_from_original_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
original_config_file=config_file,
|
||||
image_size=image_size,
|
||||
extract_ema=extract_ema,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
from_safetensors=from_safetensors,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
controlnet.to(torch_dtype=torch_dtype)
|
||||
|
||||
return controlnet
|
||||
|
||||
@@ -18,6 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalVAEMixin
|
||||
from ..utils import BaseOutput, apply_forward_hook
|
||||
from .attention_processor import AttentionProcessor, AttnProcessor
|
||||
from .modeling_utils import ModelMixin
|
||||
@@ -38,7 +39,7 @@ class AutoencoderKLOutput(BaseOutput):
|
||||
latent_dist: "DiagonalGaussianDistribution"
|
||||
|
||||
|
||||
class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalControlnetMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import AttentionProcessor, AttnProcessor
|
||||
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
@@ -100,7 +101,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
return embedding
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -90,7 +90,9 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionControlNetPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -116,7 +116,9 @@ def prepare_image(image):
|
||||
return image
|
||||
|
||||
|
||||
class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionControlNetImg2ImgPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -222,7 +222,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionControlNetInpaintPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
||||
|
||||
|
||||
@@ -621,8 +621,8 @@ def convert_ldm_unet_checkpoint(
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
vae_key = "first_stage_model."
|
||||
keys = list(checkpoint.keys())
|
||||
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
@@ -1064,7 +1064,7 @@ def convert_controlnet_checkpoint(
|
||||
if cross_attention_dim is not None:
|
||||
ctrlnet_config["cross_attention_dim"] = cross_attention_dim
|
||||
|
||||
controlnet_model = ControlNetModel(**ctrlnet_config)
|
||||
controlnet = ControlNetModel(**ctrlnet_config)
|
||||
|
||||
# Some controlnet ckpt files are distributed independently from the rest of the
|
||||
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
|
||||
@@ -1082,9 +1082,9 @@ def convert_controlnet_checkpoint(
|
||||
skip_extract_state_dict=skip_extract_state_dict,
|
||||
)
|
||||
|
||||
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
|
||||
controlnet.load_state_dict(converted_ctrl_checkpoint)
|
||||
|
||||
return controlnet_model
|
||||
return controlnet
|
||||
|
||||
|
||||
def download_from_original_stable_diffusion_ckpt(
|
||||
@@ -1181,7 +1181,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
)
|
||||
|
||||
if pipeline_class is None:
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
|
||||
|
||||
if prediction_type == "v-prediction":
|
||||
prediction_type = "v_prediction"
|
||||
@@ -1288,8 +1288,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
if controlnet is None:
|
||||
controlnet = "control_stage_config" in original_config.model.params
|
||||
|
||||
if controlnet:
|
||||
controlnet_model = convert_controlnet_checkpoint(
|
||||
controlnet = convert_controlnet_checkpoint(
|
||||
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
||||
)
|
||||
|
||||
@@ -1400,13 +1399,13 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
|
||||
if stable_unclip is None:
|
||||
if controlnet:
|
||||
pipe = StableDiffusionControlNetPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
controlnet=controlnet_model,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
@@ -1503,12 +1502,12 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
feature_extractor = None
|
||||
|
||||
if controlnet:
|
||||
pipe = StableDiffusionControlNetPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
controlnet=controlnet_model,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
@@ -1623,7 +1622,7 @@ def download_controlnet_from_original_ckpt(
|
||||
if "control_stage_config" not in original_config.model.params:
|
||||
raise ValueError("`control_stage_config` not present in original config")
|
||||
|
||||
controlnet_model = convert_controlnet_checkpoint(
|
||||
controlnet = convert_controlnet_checkpoint(
|
||||
checkpoint,
|
||||
original_config,
|
||||
checkpoint_path,
|
||||
@@ -1634,4 +1633,4 @@ def download_controlnet_from_original_ckpt(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
return controlnet_model
|
||||
return controlnet
|
||||
|
||||
@@ -199,7 +199,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
torch_dtype=torch_dtype,
|
||||
revision=revision,
|
||||
)
|
||||
model.to(torch_device).eval()
|
||||
model.to(torch_device)
|
||||
|
||||
return model
|
||||
|
||||
@@ -383,3 +383,22 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
tolerance = 3e-3 if torch_device != "mps" else 1e-2
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
|
||||
|
||||
def test_stable_diffusion_model_local(self):
|
||||
model_id = "stabilityai/sd-vae-ft-mse"
|
||||
model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
|
||||
model_2 = AutoencoderKL.from_single_file(url).to(torch_device)
|
||||
image = self.get_sd_image(33)
|
||||
|
||||
with torch.no_grad():
|
||||
sample_1 = model_1(image).sample
|
||||
sample_2 = model_2(image).sample
|
||||
|
||||
assert sample_1.shape == sample_2.shape
|
||||
|
||||
output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()
|
||||
|
||||
assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)
|
||||
|
||||
@@ -752,6 +752,42 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
)
|
||||
|
||||
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
|
||||
images.append(output.images[0])
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -401,3 +401,49 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
assert np.abs(expected_image - image).max() < 9e-2
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
)
|
||||
images.append(output.images[0])
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
|
||||
@@ -543,3 +543,54 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
assert np.abs(expected_image - image).max() < 9e-2
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetInpaintPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
).resize((512, 512))
|
||||
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
mask_image=mask_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
)
|
||||
images.append(output.images[0])
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).sum() < 1e-3
|
||||
|
||||
Reference in New Issue
Block a user