From b65861800e1cdc71ef8e666df56ef87ffbb11d86 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 19 Jan 2024 09:52:08 +0000 Subject: [PATCH] updaet --- src/diffusers/loaders/__init__.py | 4 + src/diffusers/loaders/autoencoder.py | 123 +++++++++++++++++ src/diffusers/loaders/controlnet.py | 127 ++++++++++++++++++ src/diffusers/loaders/single_file.py | 89 +++--------- src/diffusers/loaders/single_file_utils.py | 87 ++++++++++-- .../models/autoencoders/autoencoder_kl.py | 4 +- src/diffusers/models/controlnet.py | 4 +- 7 files changed, 354 insertions(+), 84 deletions(-) create mode 100644 src/diffusers/loaders/autoencoder.py create mode 100644 src/diffusers/loaders/controlnet.py diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 675246e408..58e425359e 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -56,6 +56,8 @@ _import_structure = {} if is_torch_available(): _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] + _import_structure["controlnet"] = ["FromOriginalControlnetMixin"] + _import_structure["autoencoder"] = ["FromOriginalVAEMixin"] if is_transformers_available(): _import_structure["single_file"] = ["FromSingleFileMixin"] @@ -68,6 +70,8 @@ _import_structure["peft"] = ["PeftAdapterMixin"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): + from .autoencoder import FromOriginalVAEMixin + from .controlnet import FromOriginalControlnetMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers diff --git a/src/diffusers/loaders/autoencoder.py b/src/diffusers/loaders/autoencoder.py new file mode 100644 index 0000000000..8936d4f0be --- /dev/null +++ b/src/diffusers/loaders/autoencoder.py @@ -0,0 +1,123 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub.utils import validate_hf_hub_args + +from .single_file_utils import ( + create_diffusers_vae_model_from_ldm, + fetch_ldm_config_and_checkpoint, +) + + +class FromOriginalVAEMixin: + """ + Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`]. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or + `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + Examples: + + ```py + from diffusers import StableDiffusionControlNetPipeline, ControlNetModel + + url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path + model = ControlNetModel.from_single_file(url) + + url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path + pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet) + ``` + """ + original_config_file = kwargs.pop("original_config_file", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + use_safetensors = kwargs.pop("use_safetensors", True) + + class_name = cls.__name__ + original_config, checkpoint = fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path=pretrained_model_link_or_path, + class_name=class_name, + original_config_file=original_config_file, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + use_safetensors=use_safetensors, + cache_dir=cache_dir, + ) + + image_size = kwargs.pop("image_size", None) + component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size) + vae = component["vae"] + if torch_dtype is not None: + vae = vae.to(torch_dtype) + + return vae diff --git a/src/diffusers/loaders/controlnet.py b/src/diffusers/loaders/controlnet.py new file mode 100644 index 0000000000..88008f006f --- /dev/null +++ b/src/diffusers/loaders/controlnet.py @@ -0,0 +1,127 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub.utils import validate_hf_hub_args + +from .single_file_utils import ( + create_diffusers_controlnet_model_from_ldm, + fetch_ldm_config_and_checkpoint, +) + + +class FromOriginalControlnetMixin: + """ + Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`]. + """ + + @classmethod + @validate_hf_hub_args + def from_single_file(cls, pretrained_model_link_or_path, **kwargs): + r""" + Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or + `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to True, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + image_size (`int`, *optional*, defaults to 512): + The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable + Diffusion v2 base model. Use 768 for Stable Diffusion v2. + upcast_attention (`bool`, *optional*, defaults to `None`): + Whether the attention computation should always be upcasted. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load and saveable variables (for example the pipeline components of the + specific pipeline class). The overwritten components are directly passed to the pipelines `__init__` + method. See example below for more information. + + Examples: + + ```py + from diffusers import StableDiffusionControlNetPipeline, ControlNetModel + + url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path + model = ControlNetModel.from_single_file(url) + + url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path + pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet) + ``` + """ + original_config_file = kwargs.pop("original_config_file", None) + resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + use_safetensors = kwargs.pop("use_safetensors", True) + + class_name = cls.__name__ + original_config, checkpoint = fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path=pretrained_model_link_or_path, + class_name=class_name, + original_config_file=original_config_file, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + use_safetensors=use_safetensors, + cache_dir=cache_dir, + ) + + upcast_attention = kwargs.pop("upcast_attention", False) + image_size = kwargs.pop("image_size", None) + + component = create_diffusers_controlnet_model_from_ldm( + class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size + ) + controlnet = component["controlnet"] + if torch_dtype is not None: + controlnet = controlnet.to(torch_dtype) + + return controlnet diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index d23b2b9e87..d747bfacde 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -11,32 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import re from huggingface_hub.utils import validate_hf_hub_args -from transformers import AutoFeatureExtractor -from ..models.modeling_utils import load_state_dict -from ..utils import ( - logging, -) -from ..utils.hub_utils import _get_model_file +from ..utils import logging from .single_file_utils import ( - create_diffusers_controlnet_model_from_ldm, create_diffusers_unet_model_from_ldm, create_diffusers_vae_model_from_ldm, create_scheduler_from_ldm, create_text_encoders_and_tokenizers_from_ldm, - fetch_original_config, + fetch_ldm_config_and_checkpoint, infer_model_type, ) logger = logging.get_logger(__name__) - -VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] # Pipelines that support the SDXL Refiner checkpoint REFINER_PIPELINES = [ "StableDiffusionXLImg2ImgPipeline", @@ -45,29 +35,12 @@ REFINER_PIPELINES = [ ] -def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): - pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)" - weights_name = None - repo_id = (None,) - for prefix in VALID_URL_PREFIXES: - pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "") - match = re.match(pattern, pretrained_model_name_or_path) - if not match: - return repo_id, weights_name - - repo_id = f"{match.group(1)}/{match.group(2)}" - weights_name = match.group(3) - - return repo_id, weights_name - - def build_sub_model_components( pipeline_components, pipeline_class_name, component_name, original_config, checkpoint, - checkpoint_path_or_dict, local_files_only=False, load_safety_checker=False, **kwargs, @@ -117,6 +90,8 @@ def build_sub_model_components( if component_name == "safety_checker": if load_safety_checker: + from transformers import AutoFeatureExtractor + from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker safety_checker = StableDiffusionSafetyChecker.from_pretrained( @@ -233,50 +208,20 @@ class FromSingleFileMixin: use_safetensors = kwargs.pop("use_safetensors", True) class_name = cls.__name__ - file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] - from_safetensors = file_extension == "safetensors" - if from_safetensors and use_safetensors is False: - raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") - - if os.path.isfile(pretrained_model_link_or_path): - checkpoint = load_state_dict(pretrained_model_link_or_path) - else: - repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path) - checkpoint_path = _get_model_file( - repo_id, - weights_name=weights_name, - force_download=force_download, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - ) - checkpoint = load_state_dict(checkpoint_path) - - # some checkpoints contain the model state dict under a "state_dict" key - while "state_dict" in checkpoint: - checkpoint = checkpoint["state_dict"] - - original_config = fetch_original_config(class_name, checkpoint, original_config_file) - - if class_name == "AutoencoderKL": - image_size = kwargs.pop("image_size", None) - component = create_diffusers_vae_model_from_ldm( - class_name, original_config, checkpoint, image_size=image_size - ) - return component["vae"] - - if class_name == "ControlNetModel": - upcast_attention = kwargs.pop("upcast_attention", False) - image_size = kwargs.pop("image_size", None) - - component = create_diffusers_controlnet_model_from_ldm( - class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size - ) - return component["controlnet"] + original_config, checkpoint = fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path=pretrained_model_link_or_path, + class_name=class_name, + original_config_file=original_config_file, + resume_download=resume_download, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + use_safetensors=use_safetensors, + cache_dir=cache_dir, + ) from ..pipelines.pipeline_utils import _get_pipeline_class diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 475b7d3819..386ec0bd46 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -15,20 +15,15 @@ """ Conversion script for the Stable Diffusion checkpoints.""" import os +import re from contextlib import nullcontext from io import BytesIO from urllib.parse import urlparse import requests import yaml -from transformers import ( - CLIPTextConfig, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, -) -from ..models import UNet2DConditionModel +from ..models.modeling_utils import load_state_dict from ..schedulers import ( DDIMScheduler, DDPMScheduler, @@ -39,9 +34,18 @@ from ..schedulers import ( LMSDiscreteScheduler, PNDMScheduler, ) -from ..utils import is_accelerate_available, logging +from ..utils import is_accelerate_available, is_transformers_available, logging +from ..utils.hub_utils import _get_model_file +if is_transformers_available(): + from transformers import ( + CLIPTextConfig, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + ) + if is_accelerate_available(): from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device @@ -187,6 +191,71 @@ SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [ ] +VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] + + +def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): + pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)" + weights_name = None + repo_id = (None,) + for prefix in VALID_URL_PREFIXES: + pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "") + match = re.match(pattern, pretrained_model_name_or_path) + if not match: + return repo_id, weights_name + + repo_id = f"{match.group(1)}/{match.group(2)}" + weights_name = match.group(3) + + return repo_id, weights_name + + +def fetch_ldm_config_and_checkpoint( + pretrained_model_link_or_path, + class_name, + original_config_file=None, + resume_download=False, + force_download=False, + proxies=None, + token=None, + cache_dir=None, + local_files_only=None, + revision=None, + use_safetensors=True, +): + file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] + from_safetensors = file_extension == "safetensors" + + if from_safetensors and use_safetensors is False: + raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") + + if os.path.isfile(pretrained_model_link_or_path): + checkpoint = load_state_dict(pretrained_model_link_or_path) + + else: + repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path) + checkpoint_path = _get_model_file( + repo_id, + weights_name=weights_name, + force_download=force_download, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + ) + checkpoint = load_state_dict(checkpoint_path) + + # some checkpoints contain the model state dict under a "state_dict" key + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + original_config = fetch_original_config(class_name, checkpoint, original_config_file) + + return original_config, checkpoint + + def infer_original_config_file(class_name, checkpoint): if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: config_url = CONFIG_URLS["v2"] @@ -1029,6 +1098,8 @@ def create_diffusers_unet_model_from_ldm( extract_ema=False, image_size=None, ): + from ..models import UNet2DConditionModel + if num_in_channels is None: if pipeline_class_name in [ "StableDiffusionInpaintPipeline", diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 92d12a220f..10a3ae58de 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromSingleFileMixin +from ...loaders import FromOriginalVAEMixin from ...utils.accelerate_utils import apply_forward_hook from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -32,7 +32,7 @@ from ..modeling_utils import ModelMixin from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder -class AutoencoderKL(ModelMixin, ConfigMixin, FromSingleFileMixin): +class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 8af13a6ec7..1102f4f9d3 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -19,7 +19,7 @@ from torch import nn from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromSingleFileMixin +from ..loaders import FromOriginalControlnetMixin from ..utils import BaseOutput, logging from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -102,7 +102,7 @@ class ControlNetConditioningEmbedding(nn.Module): return embedding -class ControlNetModel(ModelMixin, ConfigMixin, FromSingleFileMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): """ A ControlNet model.