mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Refactor] Update from single file (#6428)
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update'
* update
* update
* update
* update
* update
* update
* up
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* up
* update
* update
* update
* update
* update'
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* clean
* update
* update
* clean up
* clean up
* update
* clean
* clean
* update
* updaet
* clean up
* fix docs
* update
* update
* Revert "update"
This reverts commit dbfb8f1ea9.
* update
* update
* update
* update
* fix controlnet
* fix scheduler
* fix controlnet tests
This commit is contained in:
@@ -30,8 +30,8 @@ To learn more about how to load single file weights, see the [Load different Sta
|
||||
|
||||
## FromOriginalVAEMixin
|
||||
|
||||
[[autodoc]] loaders.single_file.FromOriginalVAEMixin
|
||||
[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin
|
||||
|
||||
## FromOriginalControlnetMixin
|
||||
|
||||
[[autodoc]] loaders.single_file.FromOriginalControlnetMixin
|
||||
[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin
|
||||
@@ -54,12 +54,13 @@ if is_transformers_available():
|
||||
_import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"]
|
||||
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
|
||||
|
||||
_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
|
||||
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
|
||||
_import_structure["utils"] = ["AttnProcsLayers"]
|
||||
|
||||
if is_transformers_available():
|
||||
_import_structure["single_file"].extend(["FromSingleFileMixin"])
|
||||
_import_structure["single_file"] = ["FromSingleFileMixin"]
|
||||
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
|
||||
@@ -69,7 +70,8 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
|
||||
from .autoencoder import FromOriginalVAEMixin
|
||||
from .controlnet import FromOriginalControlNetMixin
|
||||
from .unet import UNet2DConditionLoadersMixin
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
126
src/diffusers/loaders/autoencoder.py
Normal file
126
src/diffusers/loaders/autoencoder.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from .single_file_utils import (
|
||||
create_diffusers_vae_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
"""
|
||||
Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
|
||||
a VAE from SDXL or a Stable Diffusion v2 model or higher.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
class_name = cls.__name__
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
|
||||
vae = component["vae"]
|
||||
if torch_dtype is not None:
|
||||
vae = vae.to(torch_dtype)
|
||||
|
||||
return vae
|
||||
127
src/diffusers/loaders/controlnet.py
Normal file
127
src/diffusers/loaders/controlnet.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from .single_file_utils import (
|
||||
create_diffusers_controlnet_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class FromOriginalControlNetMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
class_name = cls.__name__
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
upcast_attention = kwargs.pop("upcast_attention", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
|
||||
component = create_diffusers_controlnet_model_from_ldm(
|
||||
class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
|
||||
)
|
||||
controlnet = component["controlnet"]
|
||||
if torch_dtype is not None:
|
||||
controlnet = controlnet.to(torch_dtype)
|
||||
|
||||
return controlnet
|
||||
@@ -11,39 +11,132 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..utils import deprecate, is_accelerate_available, is_transformers_available, logging
|
||||
from ..utils import is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
create_diffusers_unet_model_from_ldm,
|
||||
create_diffusers_vae_model_from_ldm,
|
||||
create_scheduler_from_ldm,
|
||||
create_text_encoders_and_tokenizers_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
infer_model_type,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
pass
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Pipelines that support the SDXL Refiner checkpoint
|
||||
REFINER_PIPELINES = [
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
def build_sub_model_components(
|
||||
pipeline_components,
|
||||
pipeline_class_name,
|
||||
component_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
local_files_only=False,
|
||||
load_safety_checker=False,
|
||||
model_type=None,
|
||||
image_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
if component_name in pipeline_components:
|
||||
return {}
|
||||
|
||||
if component_name == "unet":
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
unet_components = create_diffusers_unet_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size
|
||||
)
|
||||
return unet_components
|
||||
|
||||
if component_name == "vae":
|
||||
vae_components = create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size
|
||||
)
|
||||
return vae_components
|
||||
|
||||
if component_name == "scheduler":
|
||||
scheduler_type = kwargs.get("scheduler_type", "ddim")
|
||||
prediction_type = kwargs.get("prediction_type", None)
|
||||
|
||||
scheduler_components = create_scheduler_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
scheduler_type=scheduler_type,
|
||||
prediction_type=prediction_type,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
return scheduler_components
|
||||
|
||||
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
|
||||
text_encoder_components = create_text_encoders_and_tokenizers_from_ldm(
|
||||
original_config,
|
||||
checkpoint,
|
||||
model_type=model_type,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return text_encoder_components
|
||||
|
||||
if component_name == "safety_checker":
|
||||
if load_safety_checker:
|
||||
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
else:
|
||||
safety_checker = None
|
||||
return {"safety_checker": safety_checker}
|
||||
|
||||
if component_name == "feature_extractor":
|
||||
if load_safety_checker:
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
else:
|
||||
feature_extractor = None
|
||||
return {"feature_extractor": feature_extractor}
|
||||
|
||||
return
|
||||
|
||||
|
||||
def set_additional_components(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
model_type=None,
|
||||
):
|
||||
components = {}
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
{
|
||||
"requires_aesthetics_score": is_refiner,
|
||||
"force_zeros_for_empty_prompt": False if is_refiner else True,
|
||||
}
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_ckpt(cls, *args, **kwargs):
|
||||
deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
|
||||
deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
|
||||
return cls.from_single_file(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
@@ -58,8 +151,7 @@ class FromSingleFileMixin:
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
@@ -85,42 +177,6 @@ class FromSingleFileMixin:
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
extract_ema (`bool`, *optional*, defaults to `False`):
|
||||
Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
|
||||
higher quality images for inference. Non-EMA weights are usually better for continuing finetuning.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
prediction_type (`str`, *optional*):
|
||||
The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
|
||||
the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
|
||||
num_in_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of input channels. If `None`, it is automatically inferred.
|
||||
scheduler_type (`str`, *optional*, defaults to `"pndm"`):
|
||||
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
||||
"ddim"]`.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
||||
Whether to load the safety checker or not.
|
||||
text_encoder ([`~transformers.CLIPTextModel`], *optional*, defaults to `None`):
|
||||
An instance of `CLIPTextModel` to use, specifically the
|
||||
[clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this
|
||||
parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed.
|
||||
vae (`AutoencoderKL`, *optional*, defaults to `None`):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
|
||||
this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
|
||||
tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
|
||||
An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
|
||||
of `CLIPTokenizer` by itself if needed.
|
||||
original_config_file (`str`):
|
||||
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be
|
||||
automatically inferred by looking for a key that only exists in SD2.0 models.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
@@ -143,484 +199,80 @@ class FromSingleFileMixin:
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config_files = kwargs.pop("config_files", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scheduler_type = kwargs.pop("scheduler_type", "pndm")
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
||||
prediction_type = kwargs.pop("prediction_type", None)
|
||||
text_encoder = kwargs.pop("text_encoder", None)
|
||||
text_encoder_2 = kwargs.pop("text_encoder_2", None)
|
||||
vae = kwargs.pop("vae", None)
|
||||
controlnet = kwargs.pop("controlnet", None)
|
||||
adapter = kwargs.pop("adapter", None)
|
||||
tokenizer = kwargs.pop("tokenizer", None)
|
||||
tokenizer_2 = kwargs.pop("tokenizer_2", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
class_name = cls.__name__
|
||||
|
||||
pipeline_name = cls.__name__
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# TODO: For now we only support stable diffusion
|
||||
stable_unclip = None
|
||||
model_type = None
|
||||
|
||||
if pipeline_name in [
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
]:
|
||||
from ..models.controlnet import ControlNetModel
|
||||
from ..pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
|
||||
# list/tuple or a single instance of ControlNetModel or MultiControlNetModel
|
||||
if not (
|
||||
isinstance(controlnet, (ControlNetModel, MultiControlNetModel))
|
||||
or isinstance(controlnet, (list, tuple))
|
||||
and isinstance(controlnet[0], ControlNetModel)
|
||||
):
|
||||
raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
|
||||
elif "StableDiffusion" in pipeline_name:
|
||||
# Model type will be inferred from the checkpoint.
|
||||
pass
|
||||
elif pipeline_name == "StableUnCLIPPipeline":
|
||||
model_type = "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "txt2img"
|
||||
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
|
||||
model_type = "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "img2img"
|
||||
elif pipeline_name == "PaintByExamplePipeline":
|
||||
model_type = "PaintByExample"
|
||||
elif pipeline_name == "LDMTextToImagePipeline":
|
||||
model_type = "LDMTextToImage"
|
||||
else:
|
||||
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
||||
|
||||
# remove huggingface url
|
||||
has_valid_url_prefix = False
|
||||
valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
|
||||
for prefix in valid_url_prefixes:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
has_valid_url_prefix = True
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
if not has_valid_url_prefix:
|
||||
raise ValueError(
|
||||
f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
|
||||
)
|
||||
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
pipeline_class=cls,
|
||||
model_type=model_type,
|
||||
stable_unclip=stable_unclip,
|
||||
controlnet=controlnet,
|
||||
adapter=adapter,
|
||||
from_safetensors=from_safetensors,
|
||||
extract_ema=extract_ema,
|
||||
image_size=image_size,
|
||||
scheduler_type=scheduler_type,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
load_safety_checker=load_safety_checker,
|
||||
prediction_type=prediction_type,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
config_files=config_files,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config=None,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
model_type = kwargs.pop("model_type", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
load_safety_checker = (kwargs.pop("load_safety_checker", False)) or (
|
||||
passed_class_obj.get("safety_checker", None) is not None
|
||||
)
|
||||
|
||||
init_kwargs = {}
|
||||
for name in expected_modules:
|
||||
if name in passed_class_obj:
|
||||
init_kwargs[name] = passed_class_obj[name]
|
||||
else:
|
||||
components = build_sub_model_components(
|
||||
init_kwargs,
|
||||
class_name,
|
||||
name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
model_type=model_type,
|
||||
image_size=image_size,
|
||||
load_safety_checker=load_safety_checker,
|
||||
local_files_only=local_files_only,
|
||||
**kwargs,
|
||||
)
|
||||
if not components:
|
||||
continue
|
||||
init_kwargs.update(components)
|
||||
|
||||
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
|
||||
if additional_components:
|
||||
init_kwargs.update(additional_components)
|
||||
|
||||
init_kwargs.update(passed_pipe_kwargs)
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into an [`AutoencoderKL`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
|
||||
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
|
||||
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
|
||||
a VAE from SDXL or a Stable Diffusion v2 model or higher.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
from ..models import AutoencoderKL
|
||||
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_vae_diffusers_config,
|
||||
)
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scaling_factor = kwargs.pop("scaling_factor", None)
|
||||
kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if from_safetensors:
|
||||
from safetensors import safe_open
|
||||
|
||||
checkpoint = {}
|
||||
with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
checkpoint[key] = f.get_tensor(key)
|
||||
else:
|
||||
checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = yaml.safe_load(config_file)
|
||||
|
||||
# default to sd-v1-5
|
||||
image_size = image_size or 512
|
||||
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
if scaling_factor is None:
|
||||
if (
|
||||
"model" in original_config
|
||||
and "params" in original_config["model"]
|
||||
and "scale_factor" in original_config["model"]["params"]
|
||||
):
|
||||
vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
|
||||
else:
|
||||
vae_scaling_factor = 0.18215 # default SD scaling factor
|
||||
|
||||
vae_config["scaling_factor"] = vae_scaling_factor
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
|
||||
else:
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
vae.to(dtype=torch_dtype)
|
||||
|
||||
return vae
|
||||
|
||||
|
||||
class FromOriginalControlnetMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
use_linear_projection = kwargs.pop("use_linear_projection", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
image_size = image_size or 512
|
||||
|
||||
controlnet = download_controlnet_from_original_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
original_config_file=config_file,
|
||||
image_size=image_size,
|
||||
extract_ema=extract_ema,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
from_safetensors=from_safetensors,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
controlnet.to(dtype=torch_dtype)
|
||||
|
||||
return controlnet
|
||||
|
||||
1387
src/diffusers/loaders/single_file_utils.py
Normal file
1387
src/diffusers/loaders/single_file_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalVAEMixin
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
@@ -162,7 +161,7 @@ class TemporalDecoder(nn.Module):
|
||||
return sample
|
||||
|
||||
|
||||
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalControlnetMixin
|
||||
from ..loaders import FromOriginalControlNetMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -108,7 +108,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
return embedding
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from .. import __version__
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
SAFETENSORS_FILE_EXTENSION,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
_add_variant,
|
||||
@@ -102,10 +103,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
try:
|
||||
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
else:
|
||||
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
||||
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
else:
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
except Exception as e:
|
||||
try:
|
||||
with open(checkpoint_file) as f:
|
||||
|
||||
@@ -351,7 +351,7 @@ def get_class_obj_and_candidates(
|
||||
|
||||
def _get_pipeline_class(
|
||||
class_obj,
|
||||
config,
|
||||
config=None,
|
||||
load_connected_pipeline=False,
|
||||
custom_pipeline=None,
|
||||
repo_id=None,
|
||||
@@ -389,7 +389,12 @@ def _get_pipeline_class(
|
||||
return class_obj
|
||||
|
||||
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
|
||||
class_name = config["_class_name"]
|
||||
class_name = class_name or config["_class_name"]
|
||||
if not class_name:
|
||||
raise ValueError(
|
||||
"The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
|
||||
)
|
||||
|
||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||
|
||||
pipeline_cls = getattr(diffusers_module, class_name)
|
||||
|
||||
@@ -28,6 +28,7 @@ from .constants import (
|
||||
MIN_PEFT_VERSION,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
SAFETENSORS_FILE_EXTENSION,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
USE_PEFT_BACKEND,
|
||||
WEIGHTS_NAME,
|
||||
|
||||
@@ -31,6 +31,7 @@ WEIGHTS_NAME = "diffusion_pytorch_model.bin"
|
||||
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
||||
SAFETENSORS_FILE_EXTENSION = "safetensors"
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
|
||||
@@ -244,15 +244,15 @@ def _get_model_file(
|
||||
pretrained_model_name_or_path: Union[str, Path],
|
||||
*,
|
||||
weights_name: str,
|
||||
subfolder: Optional[str],
|
||||
cache_dir: Optional[str],
|
||||
force_download: bool,
|
||||
proxies: Optional[Dict],
|
||||
resume_download: bool,
|
||||
local_files_only: bool,
|
||||
token: Optional[str],
|
||||
user_agent: Union[Dict, str, None],
|
||||
revision: Optional[str],
|
||||
subfolder: Optional[str] = None,
|
||||
cache_dir: Optional[str] = None,
|
||||
force_download: bool = False,
|
||||
proxies: Optional[Dict] = None,
|
||||
resume_download: bool = False,
|
||||
local_files_only: bool = False,
|
||||
token: Optional[str] = None,
|
||||
user_agent: Optional[Union[Dict, str]] = None,
|
||||
revision: Optional[str] = None,
|
||||
commit_hash: Optional[str] = None,
|
||||
):
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
|
||||
@@ -37,6 +37,7 @@ from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_image,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
@@ -1022,39 +1023,49 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetPipeline.from_single_file(
|
||||
pipe_sf = StableDiffusionControlNetPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
scheduler_type="pndm",
|
||||
)
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
pipe_sf.unet.set_default_attn_processor()
|
||||
pipe_sf.enable_model_cpu_offload()
|
||||
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
prompt = "bird"
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=control_image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
|
||||
images.append(output.images[0])
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output_sf = pipe_sf(
|
||||
prompt,
|
||||
image=control_image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -39,6 +39,7 @@ from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -421,46 +422,53 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
|
||||
|
||||
def test_load_local(self):
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
|
||||
pipe_1 = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
controlnet = ControlNetModel.from_single_file(
|
||||
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
)
|
||||
pipe_2 = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
|
||||
pipe_sf = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
scheduler_type="pndm",
|
||||
)
|
||||
pipe_sf.unet.set_default_attn_processor()
|
||||
pipe_sf.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
image = load_image(
|
||||
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
|
||||
).resize((512, 512))
|
||||
prompt = "bird"
|
||||
|
||||
pipes = [pipe_1, pipe_2]
|
||||
images = []
|
||||
for pipe in pipes:
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
output = pipe(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
)
|
||||
images.append(output.images[0])
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output_sf = pipe_sf(
|
||||
prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
strength=0.9,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=3,
|
||||
).images[0]
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@@ -569,6 +569,7 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
|
||||
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
|
||||
safety_checker=None,
|
||||
controlnet=controlnet,
|
||||
scheduler_type="pndm",
|
||||
)
|
||||
control_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
@@ -605,4 +606,5 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
assert np.abs(images[0] - images[1]).max() < 1e-3
|
||||
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images[1].flatten())
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@@ -31,7 +31,14 @@ from diffusers import (
|
||||
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
|
||||
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_image,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..pipeline_params import (
|
||||
@@ -819,6 +826,41 @@ class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
|
||||
expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853])
|
||||
assert np.allclose(original_image, expected_image, atol=1e-04)
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16)
|
||||
single_file_url = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
pipe_single_file = StableDiffusionXLControlNetPipeline.from_single_file(
|
||||
single_file_url, controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe_single_file.unet.set_default_attn_processor()
|
||||
pipe_single_file.enable_model_cpu_offload()
|
||||
pipe_single_file.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "Stormtrooper's lecture"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
|
||||
)
|
||||
single_file_images = pipe_single_file(
|
||||
prompt, image=image, generator=generator, output_type="np", num_inference_steps=2
|
||||
).images
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=2).images
|
||||
|
||||
assert images[0].shape == (512, 512, 3)
|
||||
assert single_file_images[0].shape == (512, 512, 3)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), single_file_images[0].flatten())
|
||||
assert max_diff < 5e-2
|
||||
|
||||
|
||||
class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNetPipelineFastTests):
|
||||
def test_controlnet_sdxl_guess(self):
|
||||
|
||||
@@ -1262,13 +1262,13 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
|
||||
|
||||
pipe = StableDiffusionPipeline.from_single_file(ckpt_path)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_attn_processor(AttnProcessor())
|
||||
pipe.to("cuda")
|
||||
sf_pipe = StableDiffusionPipeline.from_single_file(ckpt_path)
|
||||
sf_pipe.scheduler = DDIMScheduler.from_config(sf_pipe.scheduler.config)
|
||||
sf_pipe.unet.set_attn_processor(AttnProcessor())
|
||||
sf_pipe.to("cuda")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_ckpt = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
image_single_file = sf_pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
@@ -1278,7 +1278,7 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ from diffusers.utils.testing_utils import (
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
@@ -771,7 +772,9 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
|
||||
inputs["num_inference_steps"] = 5
|
||||
image = pipe(**inputs).images[0]
|
||||
|
||||
assert np.max(np.abs(image - image_ckpt)) < 5e-4
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
|
||||
|
||||
assert max_diff < 1e-4
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -1024,6 +1025,11 @@ class StableDiffusionXLPipelineFastTests(
|
||||
|
||||
@slow
|
||||
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_diffusion_lcm(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
@@ -1049,3 +1055,30 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
|
||||
|
||||
assert max_diff < 1e-2
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_ckpt = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
|
||||
|
||||
assert max_diff < 6e-3
|
||||
|
||||
@@ -699,3 +699,40 @@ class AdapterSDXLPipelineSlowTests(unittest.TestCase):
|
||||
image_slice = images[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.4284, 0.4337, 0.4319, 0.4255, 0.4329, 0.4280, 0.4338, 0.4420, 0.4226])
|
||||
assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
|
||||
)
|
||||
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
|
||||
prompt = "toy"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
|
||||
)
|
||||
pipe_single_file = StableDiffusionXLAdapterPipeline.from_single_file(
|
||||
ckpt_path,
|
||||
adapter=adapter,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe_single_file.enable_model_cpu_offload()
|
||||
pipe_single_file.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
images_single_file = pipe_single_file(
|
||||
prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
|
||||
).images
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
adapter=adapter,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
|
||||
|
||||
assert images_single_file[0].shape == (768, 512, 3)
|
||||
assert images[0].shape == (768, 512, 3)
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images_single_file[0].flatten())
|
||||
assert max_diff < 5e-3
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
@@ -31,15 +32,19 @@ from transformers import (
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
AutoencoderTiny,
|
||||
DDIMScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LCMScheduler,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -763,3 +768,44 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
|
||||
|
||||
def test_save_load_optional_components(self):
|
||||
self._test_save_load_optional_components()
|
||||
|
||||
|
||||
@slow
|
||||
class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_download_ckpt_diff_format_is_same(self):
|
||||
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_img2img/sketch-mountains-input.png"
|
||||
)
|
||||
|
||||
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.unet.set_default_attn_processor()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image = pipe(
|
||||
prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np"
|
||||
).images[0]
|
||||
|
||||
pipe_single_file = StableDiffusionXLImg2ImgPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
|
||||
pipe_single_file.scheduler = DDIMScheduler.from_config(pipe_single_file.scheduler.config)
|
||||
pipe_single_file.unet.set_default_attn_processor()
|
||||
pipe_single_file.enable_model_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_single_file = pipe_single_file(
|
||||
prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np"
|
||||
).images[0]
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
|
||||
|
||||
assert max_diff < 5e-2
|
||||
|
||||
Reference in New Issue
Block a user