1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Refactor] Better align from_single_file logic with from_pretrained (#7496)

* refactor unet single file loading a bit.

* retrieve the unet from create_diffusers_unet_model_from_ldm

* update

* update

* updae

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* tests

* update

* update

* update

* Update docs/source/en/api/single_file.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/single_file.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/api/loaders/single_file.md

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/loaders/single_file.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update docs/source/en/api/loaders/single_file.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/loaders/single_file.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/loaders/single_file.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/en/api/loaders/single_file.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
Dhruv Nair
2024-05-09 19:00:19 +05:30
committed by GitHub
parent caf9e985df
commit cb0f3b49cb
45 changed files with 3811 additions and 1768 deletions

View File

@@ -124,7 +124,7 @@ jobs:
shell: bash
strategy:
matrix:
module: [models, schedulers, lora, others]
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
uses: actions/checkout@v3

View File

@@ -10,13 +10,124 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Single files
# Loading Pipelines and Models via `from_single_file`
Diffusers supports loading pretrained pipeline (or model) weights stored in a single file, such as a `ckpt` or `safetensors` file. These single file types are typically produced from community trained models. There are three classes for loading single file weights:
The `from_single_file` method allows you to load supported pipelines using a single checkpoint file as opposed to the folder format used by Diffusers. This is useful if you are working with many of the Stable Diffusion Web UI's (such as A1111) that extensively rely on a single file to distribute all the components of a diffusion model.
- [`FromSingleFileMixin`] supports loading pretrained pipeline weights stored in a single file, which can either be a `ckpt` or `safetensors` file.
- [`FromOriginalVAEMixin`] supports loading a pretrained [`AutoencoderKL`] from pretrained ControlNet weights stored in a single file, which can either be a `ckpt` or `safetensors` file.
- [`FromOriginalControlnetMixin`] supports loading pretrained ControlNet weights stored in a single file, which can either be a `ckpt` or `safetensors` file.
The `from_single_file` method also supports loading models in their originally distributed format. This means that supported models that have been finetuned with other services can be loaded directly into supported Diffusers model objects and pipelines.
## Pipelines that currently support `from_single_file` loading
- [`StableDiffusionPipeline`]
- [`StableDiffusionImg2ImgPipeline`]
- [`StableDiffusionInpaintPipeline`]
- [`StableDiffusionControlNetPipeline`]
- [`StableDiffusionControlNetImg2ImgPipeline`]
- [`StableDiffusionControlNetInpaintPipeline`]
- [`StableDiffusionUpscalePipeline`]
- [`StableDiffusionXLPipeline`]
- [`StableDiffusionXLImg2ImgPipeline`]
- [`StableDiffusionXLInpaintPipeline`]
- [`StableDiffusionXLInstructPix2PixPipeline`]
- [`StableDiffusionXLControlNetPipeline`]
- [`StableDiffusionXLKDiffusionPipeline`]
- [`LatentConsistencyModelPipeline`]
- [`LatentConsistencyModelImg2ImgPipeline`]
- [`StableDiffusionControlNetXSPipeline`]
- [`StableDiffusionXLControlNetXSPipeline`]
- [`LEditsPPPipelineStableDiffusion`]
- [`LEditsPPPipelineStableDiffusionXL`]
- [`PIAPipeline`]
## Models that currently support `from_single_file` loading
- [`UNet2DConditionModel`]
- [`StableCascadeUNet`]
- [`AutoencoderKL`]
- [`ControlNetModel`]
## Usage Examples
## Loading a Pipeline using `from_single_file`
```python
from diffusers import StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path)
```
## Setting components in a Pipeline using `from_single_file`
Swap components of the pipeline by passing them directly to the `from_single_file` method. e.g If you would like use a different scheduler than the pipeline default.
```python
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
scheduler = DDIMScheduler()
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, scheduler=scheduler)
```
```python
from diffusers import StableDiffusionPipeline, ControlNetModel
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
controlnet = ControlNetModel.from_pretrained("https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors")
pipe = StableDiffusionPipeline.from_single_file(ckpt_path, controlnet=controlnet)
```
## Loading a Model using `from_single_file`
```python
from diffusers import StableCascadeUNet
ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
model = StableCascadeUNet.from_single_file(ckpt_path)
```
## Using a Diffusers model repository to configure single file loading
Under the hood, `from_single_file` will try to determine a model repository to use to configure the components of the pipeline. You can also pass in a repository id to the `config` argument of the `from_single_file` method to explicitly set the repository to use.
```python
from diffusers import StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/segmind/SSD-1B/blob/main/SSD-1B.safetensors"
repo_id = "segmind/SSD-1B"
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, config=repo_id)
```
## Override configuration options when using single file loading
Override the default model or pipeline configuration options when using `from_single_file` by passing in the relevant arguments directly to the `from_single_file` method. Any argument that is supported by the model or pipeline class can be configured in this way:
```python
from diffusers import StableDiffusionXLInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
pipe = StableDiffusionXLInstructPix2PixPipeline.from_single_file(ckpt_path, config="diffusers/sdxl-instructpix2pix-768", is_cosxl_edit=True)
```
```python
from diffusers import UNet2DConditionModel
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
model = UNet2DConditionModel.from_single_file(ckpt_path, upcast_attention=True)
```
In the example above, since we explicitly passed `repo_id="segmind/SSD-1B"`, it will use this [configuration file](https://huggingface.co/segmind/SSD-1B/blob/main/unet/config.json) from the "unet" subfolder in `"segmind/SSD-1B"` to configure the unet component included in the checkpoint; Similarly, it will use the `config.json` file from `"vae"` subfolder to configure the vae model, `config.json` file from text_encoder folder to configure text_encoder and so on.
Note that most of the time you do not need to explicitly a `config` argument, `from_single_file` will automatically map the checkpoint to a repo id (we will discuss this in more details in next section). However, this can be useful in cases where model components might have been changed from what was originally distributed or in cases where a checkpoint file might not have the necessary metadata to correctly determine the configuration to use for the pipeline.
<Tip>
@@ -24,14 +135,114 @@ To learn more about how to load single file weights, see the [Load different Sta
</Tip>
## Working with local files
As of `diffusers>=0.28.0` the `from_single_file` method will attempt to configure a pipeline or model by first inferring the model type from the checkpoint file and then using the model type to determine the appropriate model repo configuration to use from the Hugging Face Hub. For example, any single file checkpoint based on the Stable Diffusion XL base model will use the [`stabilityai/stable-diffusion-xl-base-1.0`](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model repo to configure the pipeline.
If you are working in an environment with restricted internet access, it is recommended to download the config files and checkpoints for the model to your preferred directory and pass the local paths to the `pretrained_model_link_or_path` and `config` arguments of the `from_single_file` method.
```python
from huggingface_hub import hf_hub_download, snapshot_download
my_local_checkpoint_path = hf_hub_download(
repo_id="segmind/SSD-1B",
filename="SSD-1B.safetensors"
)
my_local_config_path = snapshot_download(
repo_id="segmind/SSD-1B",
allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
)
pipe = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
```
By default this will download the checkpoints and config files to the [Hugging Face Hub cache directory](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache). You can also specify a local directory to download the files to by passing the `local_dir` argument to the `hf_hub_download` and `snapshot_download` functions.
```python
from huggingface_hub import hf_hub_download, snapshot_download
my_local_checkpoint_path = hf_hub_download(
repo_id="segmind/SSD-1B",
filename="SSD-1B.safetensors"
local_dir="my_local_checkpoints"
)
my_local_config_path = snapshot_download(
repo_id="segmind/SSD-1B",
allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
local_dir="my_local_config"
)
pipe = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
```
## Working with local files on file systems that do not support symlinking
By default the `from_single_file` method relies on the `huggingface_hub` caching mechanism to fetch and store checkpoints and config files for models and pipelines. If you are working with a file system that does not support symlinking, it is recommended that you first download the checkpoint file to a local directory and disable symlinking by passing the `local_dir_use_symlink=False` argument to the `hf_hub_download` and `snapshot_download` functions.
```python
from huggingface_hub import hf_hub_download, snapshot_download
my_local_checkpoint_path = hf_hub_download(
repo_id="segmind/SSD-1B",
filename="SSD-1B.safetensors"
local_dir="my_local_checkpoints",
local_dir_use_symlinks=False
)
print("My local checkpoint: ", my_local_checkpoint_path)
my_local_config_path = snapshot_download(
repo_id="segmind/SSD-1B",
allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
local_dir_use_symlinks=False,
)
print("My local config: ", my_local_config_path)
```
Then pass the local paths to the `pretrained_model_link_or_path` and `config` arguments of the `from_single_file` method.
```python
pipe = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
```
<Tip>
Disabling symlinking means that the `huggingface_hub` caching mechanism has no way to determine whether a file has already been downloaded to the local directory. This means that the `hf_hub_download` and `snapshot_download` functions will download files to the local directory each time they are executed. If you are disabling symlinking, it is recommended that you separate the model download and loading steps to avoid downloading the same file multiple times.
</Tip>
## Using the original configuration file of a model
If you would like to configure the parameters of the model components in the pipeline using the orignal YAML configuration file, you can pass a local path or url to the original configuration file to the `original_config` argument of the `from_single_file` method.
```python
from diffusers import StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
original_config = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, original_config=original_config)
```
In the example above, the `original_config` file is only used to configure the parameters of the individual model components of the pipeline. For example it will be used to configure parameters such as the `in_channels` of the `vae` model and `unet` model. It is not used to determine the type of component objects in the pipeline.
<Tip>
When using `original_config` with local_files_only=True`, Diffusers will attempt to infer the components based on the type signatures of pipeline class, rather than attempting to fetch the pipeline config from the Hugging Face Hub. This is to prevent backwards breaking changes in existing code that might not be able to connect to the internet to fetch the necessary pipeline config files.
This is not as reliable as providing a path to a local config repo and might lead to errors when configuring the pipeline. To avoid this, please run the pipeline with `local_files_only=False` once to download the appropriate pipeline config files to the local cache.
</Tip>
## FromSingleFileMixin
[[autodoc]] loaders.single_file.FromSingleFileMixin
## FromOriginalVAEMixin
## FromOriginalModelMixin
[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin
## FromOriginalControlnetMixin
[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin
[[autodoc]] loaders.single_file_model.FromOriginalModelMixin

View File

@@ -27,6 +27,7 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"schedulers": [],

View File

@@ -340,6 +340,8 @@ class ConfigMixin:
"""
cache_dir = kwargs.pop("cache_dir", None)
local_dir = kwargs.pop("local_dir", None)
local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", None)
proxies = kwargs.pop("proxies", None)
@@ -364,13 +366,13 @@ class ConfigMixin:
if os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
# Load from a PyTorch checkpoint
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
elif subfolder is not None and os.path.isfile(
if subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
):
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
# Load from a PyTorch checkpoint
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
else:
raise EnvironmentError(
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
@@ -390,6 +392,8 @@ class ConfigMixin:
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
)
except RepositoryNotFoundError:
raise EnvironmentError(

View File

@@ -54,9 +54,7 @@ if is_transformers_available():
_import_structure = {}
if is_torch_available():
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
@@ -70,8 +68,7 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .autoencoder import FromOriginalVAEMixin
from .controlnet import FromOriginalControlNetMixin
from .single_file_model import FromOriginalModelMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers

View File

@@ -11,144 +11,243 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
from huggingface_hub.utils import validate_hf_hub_args
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
from packaging import version
from ..utils import is_transformers_available, logging
from ..utils import deprecate, is_transformers_available, logging
from .single_file_utils import (
create_diffusers_unet_model_from_ldm,
create_diffusers_vae_model_from_ldm,
create_scheduler_from_ldm,
create_text_encoders_and_tokenizers_from_ldm,
fetch_ldm_config_and_checkpoint,
infer_model_type,
SingleFileComponentError,
_is_model_weights_in_cached_folder,
_legacy_load_clip_tokenizer,
_legacy_load_safety_checker,
_legacy_load_scheduler,
create_diffusers_clip_model_from_ldm,
fetch_diffusers_config,
fetch_original_config,
is_clip_model_in_single_file,
load_single_file_checkpoint,
)
logger = logging.get_logger(__name__)
# Pipelines that support the SDXL Refiner checkpoint
REFINER_PIPELINES = [
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
]
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
if is_transformers_available():
from transformers import AutoFeatureExtractor
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer
def build_sub_model_components(
pipeline_components,
pipeline_class_name,
component_name,
original_config,
def load_single_file_sub_model(
library_name,
class_name,
name,
checkpoint,
pipelines,
is_pipeline_module,
cached_model_config_path,
original_config=None,
local_files_only=False,
load_safety_checker=False,
model_type=None,
image_size=None,
torch_dtype=None,
is_legacy_loading=False,
**kwargs,
):
if component_name in pipeline_components:
return {}
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
if component_name == "unet":
num_in_channels = kwargs.pop("num_in_channels", None)
upcast_attention = kwargs.pop("upcast_attention", None)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
unet_components = create_diffusers_unet_model_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
num_in_channels=num_in_channels,
image_size=image_size,
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
is_tokenizer = (
is_transformers_available()
and issubclass(class_obj, PreTrainedTokenizer)
and transformers_version >= version.parse("4.20.0")
)
diffusers_module = importlib.import_module(__name__.split(".")[0])
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
if is_diffusers_single_file_model:
load_method = getattr(class_obj, "from_single_file")
# We cannot provide two different config options to the `from_single_file` method
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
if original_config:
cached_model_config_path = None
loaded_sub_model = load_method(
pretrained_model_link_or_path_or_dict=checkpoint,
original_config=original_config,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
model_type=model_type,
upcast_attention=upcast_attention,
)
return unet_components
if component_name == "vae":
scaling_factor = kwargs.get("scaling_factor", None)
vae_components = create_diffusers_vae_model_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
image_size,
scaling_factor,
torch_dtype,
model_type=model_type,
)
return vae_components
if component_name == "scheduler":
scheduler_type = kwargs.get("scheduler_type", "ddim")
prediction_type = kwargs.get("prediction_type", None)
scheduler_components = create_scheduler_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
scheduler_type=scheduler_type,
prediction_type=prediction_type,
model_type=model_type,
)
return scheduler_components
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
text_encoder_components = create_text_encoders_and_tokenizers_from_ldm(
original_config,
checkpoint,
model_type=model_type,
local_files_only=local_files_only,
torch_dtype=torch_dtype,
**kwargs,
)
return text_encoder_components
if component_name == "safety_checker":
if load_safety_checker:
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
loaded_sub_model = create_diffusers_clip_model_from_ldm(
class_obj,
checkpoint=checkpoint,
config=cached_model_config_path,
subfolder=name,
torch_dtype=torch_dtype,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
elif is_tokenizer and is_legacy_loading:
loaded_sub_model = _legacy_load_clip_tokenizer(
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
)
elif is_diffusers_scheduler and is_legacy_loading:
loaded_sub_model = _legacy_load_scheduler(
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
)
else:
if not hasattr(class_obj, "from_pretrained"):
raise ValueError(
(
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
" a supported loading method."
)
)
else:
safety_checker = None
return {"safety_checker": safety_checker}
if component_name == "feature_extractor":
if load_safety_checker:
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
else:
feature_extractor = None
return {"feature_extractor": feature_extractor}
return
def set_additional_components(
pipeline_class_name,
original_config,
checkpoint=None,
model_type=None,
):
components = {}
if pipeline_class_name in REFINER_PIPELINES:
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
is_refiner = model_type == "SDXL-Refiner"
components.update(
loading_kwargs = {}
loading_kwargs.update(
{
"requires_aesthetics_score": is_refiner,
"force_zeros_for_empty_prompt": False if is_refiner else True,
"pretrained_model_name_or_path": cached_model_config_path,
"subfolder": name,
"local_files_only": local_files_only,
}
)
return components
# Schedulers and Tokenizers don't make use of torch_dtype
# Skip passing it to those objects
if issubclass(class_obj, torch.nn.Module):
loading_kwargs.update({"torch_dtype": torch_dtype})
if is_diffusers_model or is_transformers_model:
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
)
load_method = getattr(class_obj, "from_pretrained")
loaded_sub_model = load_method(**loading_kwargs)
return loaded_sub_model
def _map_component_types_to_config_dict(component_types):
diffusers_module = importlib.import_module(__name__.split(".")[0])
config_dict = {}
component_types.pop("self", None)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
for component_name, component_value in component_types.items():
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
is_transformers_model = (
is_transformers_available()
and issubclass(component_value[0], PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
is_transformers_tokenizer = (
is_transformers_available()
and issubclass(component_value[0], PreTrainedTokenizer)
and transformers_version >= version.parse("4.20.0")
)
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
config_dict[component_name] = ["diffusers", component_value[0].__name__]
elif is_scheduler_enum or is_scheduler:
if is_scheduler_enum:
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
# if the type hint is a KarrassDiffusionSchedulers enum
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
elif is_scheduler:
config_dict[component_name] = ["diffusers", component_value[0].__name__]
elif (
is_transformers_model or is_transformers_tokenizer
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
config_dict[component_name] = ["transformers", component_value[0].__name__]
else:
config_dict[component_name] = [None, None]
return config_dict
def _infer_pipeline_config_dict(pipeline_class):
parameters = inspect.signature(pipeline_class.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
component_types = pipeline_class._get_signature_types()
# Ignore parameters that are not required for the pipeline
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
config_dict = _map_component_types_to_config_dict(component_types)
return config_dict
def _download_diffusers_model_config_from_hub(
pretrained_model_name_or_path,
cache_dir,
revision,
proxies,
force_download=None,
resume_download=None,
local_files_only=None,
token=None,
):
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]
cached_model_path = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
allow_patterns=allow_patterns,
)
return cached_model_path
class FromSingleFileMixin:
@@ -195,27 +294,12 @@ class FromSingleFileMixin:
original_config_file (`str`, *optional*):
The path to the original config file that was used to train the model. If not provided, the config file
will be inferred from the checkpoint file.
model_type (`str`, *optional*):
The type of model to load. If not provided, the model type will be inferred from the checkpoint file.
image_size (`int`, *optional*):
The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE
model.
load_safety_checker (`bool`, *optional*, defaults to `False`):
Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a
`safety_checker` component is passed to the `kwargs`.
num_in_channels (`int`, *optional*):
Specify the number of input channels for the UNet model. Read more about how to configure UNet model
with this parameter
[here](https://huggingface.co/docs/diffusers/training/adapt_a_model#configure-unet2dconditionmodel-parameters).
scaling_factor (`float`, *optional*):
The scaling factor to use for the VAE model. If not provided, it is inferred from the config file
first. If the scaling factor is not found in the config file, the default value 0.18215 is used.
scheduler_type (`str`, *optional*):
The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint
file.
prediction_type (`str`, *optional*):
The type of prediction to load. If not provided, the prediction type will be inferred from the
checkpoint file.
config (`str`, *optional*):
Can be either:
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
hosted on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
component configs in Diffusers format.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -233,7 +317,7 @@ class FromSingleFileMixin:
>>> # Download pipeline from local file
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
>>> # Enable float16 and move to GPU
>>> pipeline = StableDiffusionPipeline.from_single_file(
@@ -242,9 +326,21 @@ class FromSingleFileMixin:
... )
>>> pipeline.to("cuda")
```
"""
original_config_file = kwargs.pop("original_config_file", None)
resume_download = kwargs.pop("resume_download", None)
config = kwargs.pop("config", None)
original_config = kwargs.pop("original_config", None)
if original_config_file is not None:
deprecation_message = (
"`original_config_file` argument is deprecated and will be removed in future versions."
"please use the `original_config` argument instead."
)
deprecate("original_config_file", "1.0.0", deprecation_message)
original_config = original_config_file
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
@@ -253,68 +349,198 @@ class FromSingleFileMixin:
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
class_name = cls.__name__
is_legacy_loading = False
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=class_name,
original_config_file=original_config_file,
# We shouldn't allow configuring individual models components through a Pipeline creation method
# These model kwargs should be deprecated
scaling_factor = kwargs.get("scaling_factor", None)
if scaling_factor is not None:
deprecation_message = (
"Passing the `scaling_factor` argument to `from_single_file is deprecated "
"and will be ignored in future versions."
)
deprecate("scaling_factor", "1.0.0", deprecation_message)
if original_config is not None:
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
from ..pipelines.pipeline_utils import _get_pipeline_class
pipeline_class = _get_pipeline_class(cls, config=None)
checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
cache_dir=cache_dir,
local_files_only=local_files_only,
cache_dir=cache_dir,
revision=revision,
)
from ..pipelines.pipeline_utils import _get_pipeline_class
if config is None:
config = fetch_diffusers_config(checkpoint)
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
else:
default_pretrained_model_config_name = config
pipeline_class = _get_pipeline_class(
cls,
config=None,
cache_dir=cache_dir,
)
if not os.path.isdir(default_pretrained_model_config_name):
# Provided config is a repo_id
if default_pretrained_model_config_name.count("/") > 1:
raise ValueError(
f'The provided config "{config}"'
" is neither a valid local path nor a valid repo id. Please check the parameter."
)
try:
# Attempt to download the config files for the pipeline
cached_model_config_path = _download_diffusers_model_config_from_hub(
default_pretrained_model_config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
)
config_dict = pipeline_class.load_config(cached_model_config_path)
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
except LocalEntryNotFoundError:
# `local_files_only=True` but a local diffusers format model config is not available in the cache
# If `original_config` is not provided, we need override `local_files_only` to False
# to fetch the config files from the hub so that we have a way
# to configure the pipeline components.
if original_config is None:
logger.warning(
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
"Attempting to download the necessary config files for this pipeline.\n"
)
cached_model_config_path = _download_diffusers_model_config_from_hub(
default_pretrained_model_config_name,
cache_dir=cache_dir,
revision=revision,
proxies=proxies,
force_download=force_download,
resume_download=resume_download,
local_files_only=False,
token=token,
)
config_dict = pipeline_class.load_config(cached_model_config_path)
else:
# For backwards compatibility
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
logger.warning(
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
"This may lead to errors if the model components are not correctly inferred. \n"
"To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
"the necessary config files.\n"
)
is_legacy_loading = True
cached_model_config_path = None
config_dict = _infer_pipeline_config_dict(pipeline_class)
config_dict["_class_name"] = pipeline_class.__name__
else:
# Provided config is a path to a local directory attempt to load directly.
cached_model_config_path = default_pretrained_model_config_name
config_dict = pipeline_class.load_config(cached_model_config_path)
# pop out "_ignore_files" as it is only needed for download
config_dict.pop("_ignore_files", None)
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
model_type = kwargs.pop("model_type", None)
image_size = kwargs.pop("image_size", None)
load_safety_checker = (kwargs.pop("load_safety_checker", False)) or (
passed_class_obj.get("safety_checker", None) is not None
)
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
from diffusers import pipelines
# remove `null` components
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
for name, (library_name, class_name) in logging.tqdm(
sorted(init_dict.items()), desc="Loading pipeline components..."
):
loaded_sub_model = None
is_pipeline_module = hasattr(pipelines, library_name)
init_kwargs = {}
for name in expected_modules:
if name in passed_class_obj:
init_kwargs[name] = passed_class_obj[name]
loaded_sub_model = passed_class_obj[name]
else:
components = build_sub_model_components(
init_kwargs,
class_name,
name,
original_config,
checkpoint,
model_type=model_type,
image_size=image_size,
load_safety_checker=load_safety_checker,
local_files_only=local_files_only,
torch_dtype=torch_dtype,
**kwargs,
)
if not components:
continue
init_kwargs.update(components)
try:
loaded_sub_model = load_single_file_sub_model(
library_name=library_name,
class_name=class_name,
name=name,
checkpoint=checkpoint,
is_pipeline_module=is_pipeline_module,
cached_model_config_path=cached_model_config_path,
pipelines=pipelines,
torch_dtype=torch_dtype,
original_config=original_config,
local_files_only=local_files_only,
is_legacy_loading=is_legacy_loading,
**kwargs,
)
except SingleFileComponentError as e:
raise SingleFileComponentError(
(
f"{e.message}\n"
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
f"\n"
f"{name} = {class_name}.from_pretrained('...')\n"
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
f"\n"
)
)
additional_components = set_additional_components(
class_name, original_config, checkpoint=checkpoint, model_type=model_type
)
if additional_components:
init_kwargs.update(additional_components)
init_kwargs[name] = loaded_sub_model
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
for module in missing_modules:
init_kwargs[module] = passed_class_obj.get(module, None)
elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
# deprecated kwargs
load_safety_checker = kwargs.pop("load_safety_checker", None)
if load_safety_checker is not None:
deprecation_message = (
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
)
deprecate("load_safety_checker", "1.0.0", deprecation_message)
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
init_kwargs.update(safety_checker_components)
init_kwargs.update(passed_pipe_kwargs)
pipe = pipeline_class(**init_kwargs)
if torch_dtype is not None:

View File

@@ -0,0 +1,290 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from contextlib import nullcontext
from typing import Optional
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
convert_controlnet_checkpoint,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_stable_cascade_unet_single_file_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
fetch_diffusers_config,
fetch_original_config,
load_single_file_checkpoint,
)
logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import init_empty_weights
from ..models.modeling_utils import load_model_dict_into_meta
SINGLE_FILE_LOADABLE_CLASSES = {
"StableCascadeUNet": {
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
},
"UNet2DConditionModel": {
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
"default_subfolder": "unet",
"legacy_kwargs": {
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
},
},
"AutoencoderKL": {
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
"default_subfolder": "vae",
},
"ControlNetModel": {
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
},
}
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
parameters = inspect.signature(mapping_fn).parameters
mapping_kwargs = {}
for parameter in parameters:
if parameter in kwargs:
mapping_kwargs[parameter] = kwargs[parameter]
return mapping_kwargs
class FromOriginalModelMixin:
"""
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
r"""
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path_or_dict (`str`, *optional*):
Can be either:
- A link to the `.safetensors` or `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
- A path to a local *file* containing the weights of the component model.
- A state dict containing the component model weights.
config (`str`, *optional*):
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
on the Hub.
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
configs in Diffusers format.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
original_config (`str`, *optional*):
Dict or path to a yaml file containing the configuration for the model in its original format.
If a dict is provided, it will be used to initialize the model configuration.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
```py
>>> from diffusers import StableCascadeUNet
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
```
"""
class_name = cls.__name__
if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
raise ValueError(
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
)
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
if pretrained_model_link_or_path is not None:
deprecation_message = (
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
)
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
config = kwargs.pop("config", None)
original_config = kwargs.pop("original_config", None)
if config is not None and original_config is not None:
raise ValueError(
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
else:
checkpoint = load_single_file_checkpoint(
pretrained_model_link_or_path_or_dict,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
if original_config:
if "config_mapping_fn" in mapping_functions:
config_mapping_fn = mapping_functions["config_mapping_fn"]
else:
config_mapping_fn = None
if config_mapping_fn is None:
raise ValueError(
(
f"`original_config` has been provided for {class_name} but no mapping function"
"was found to convert the original config to a Diffusers config in"
"`diffusers.loaders.single_file_utils`"
)
)
if isinstance(original_config, str):
# If original_config is a URL or filepath fetch the original_config dict
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
diffusers_model_config = config_mapping_fn(
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
)
else:
if config:
if isinstance(config, str):
default_pretrained_model_config_name = config
else:
raise ValueError(
(
"Invalid `config` argument. Please provide a string representing a repo id"
"or path to a local Diffusers model repo."
)
)
else:
config = fetch_diffusers_config(checkpoint)
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
if "default_subfolder" in mapping_functions:
subfolder = mapping_functions["default_subfolder"]
subfolder = subfolder or config.pop(
"subfolder", None
) # some configs contain a subfolder key, e.g. StableCascadeUNet
diffusers_model_config = cls.load_config(
pretrained_model_name_or_path=default_pretrained_model_config_name,
subfolder=subfolder,
local_files_only=local_files_only,
)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
# Map legacy kwargs to new kwargs
if "legacy_kwargs" in mapping_functions:
legacy_kwargs = mapping_functions["legacy_kwargs"]
for legacy_key, new_key in legacy_kwargs.items():
if legacy_key in kwargs:
kwargs[new_key] = kwargs.pop(legacy_key)
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
diffusers_model_config.update(model_kwargs)
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
diffusers_format_checkpoint = checkpoint_mapping_fn(
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
)
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if model._keys_to_ignore_on_load_unexpected is not None:
for pat in model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
model.load_state_dict(diffusers_format_checkpoint)
if torch_dtype is not None:
model.to(torch_dtype)
model.eval()
return model

File diff suppressed because it is too large Load Diff

View File

@@ -44,11 +44,6 @@ from ..utils import (
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
infer_stable_cascade_single_file_config,
load_single_file_model_checkpoint,
)
from .unet_loader_utils import _maybe_expand_lora_scales
from .utils import AttnProcsLayers
@@ -1059,103 +1054,3 @@ class UNet2DConditionLoadersMixin:
}
)
return lora_dicts
class FromOriginalUNetMixin:
"""
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
config: (`dict`, *optional*):
Dictionary containing the configuration of the model:
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
of Diffusers.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables of the model.
"""
class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
config = kwargs.pop("config", None)
resume_download = kwargs.pop("resume_download", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)
if config is None:
config = infer_stable_cascade_single_file_config(checkpoint)
model_config = cls.load_config(**config, **kwargs)
else:
model_config = config
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(model_config, **kwargs)
diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
model.load_state_dict(diffusers_format_checkpoint)
if torch_dtype is not None:
model.to(torch_dtype)
return model

View File

@@ -17,7 +17,7 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -32,7 +32,7 @@ from ..modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.

View File

@@ -19,7 +19,7 @@ from torch import nn
from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlNetMixin
from ..loaders.single_file_model import FromOriginalModelMixin
from ..utils import BaseOutput, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -108,7 +108,7 @@ class ControlNetConditioningEmbedding(nn.Module):
return embedding
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
A ControlNet model.

View File

@@ -963,6 +963,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
@classmethod
def _get_signature_keys(cls, obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
return expected_modules, optional_parameters
# Adapted from `transformers` modeling_utils.py
def _get_no_split_modules(self, device_map: str):
"""

View File

@@ -20,6 +20,7 @@ import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
@@ -66,7 +67,9 @@ class UNet2DConditionOutput(BaseOutput):
sample: torch.FloatTensor = None
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
class UNet2DConditionModel(
ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.

View File

@@ -21,7 +21,7 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.unet import FromOriginalUNetMixin
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput
from ..attention_processor import Attention
from ..modeling_utils import ModelMixin
@@ -134,7 +134,7 @@ class StableCascadeUNetOutput(BaseOutput):
sample: torch.FloatTensor = None
class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin):
class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
@register_to_config

View File

@@ -609,6 +609,7 @@ def load_sub_model(
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
# retrieve class candidates
class_obj, class_candidates = get_class_obj_and_candidates(
library_name,
class_name,

View File

@@ -791,62 +791,6 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
tolerance = 3e-3 if torch_device != "mps" else 1e-2
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
def test_stable_diffusion_model_local(self):
model_id = "stabilityai/sd-vae-ft-mse"
model_1 = AutoencoderKL.from_pretrained(model_id).to(torch_device)
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
model_2 = AutoencoderKL.from_single_file(url).to(torch_device)
image = self.get_sd_image(33)
with torch.no_grad():
sample_1 = model_1(image).sample
sample_2 = model_2(image).sample
assert sample_1.shape == sample_2.shape
output_slice_1 = sample_1[-1, -2:, -2:, :2].flatten().float().cpu()
output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()
assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)
def test_single_file_component_configs(self):
vae_single_file = AutoencoderKL.from_single_file(
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
)
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
for param_name, param_value in vae_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
vae.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
def test_single_file_arguments(self):
vae_default = AutoencoderKL.from_single_file(
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
)
assert vae_default.config.scaling_factor == 0.18215
assert vae_default.config.sample_size == 512
assert vae_default.dtype == torch.float32
scaling_factor = 2.0
image_size = 256
torch_dtype = torch.float16
vae = AutoencoderKL.from_single_file(
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
image_size=image_size,
scaling_factor=scaling_factor,
torch_dtype=torch_dtype,
)
assert vae.config.scaling_factor == scaling_factor
assert vae.config.sample_size == image_size
assert vae.dtype == torch_dtype
@slow
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):

View File

@@ -56,7 +56,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in single_file_unet_config.items():
if param_name in PARAMS_TO_IGNORE:
continue
@@ -78,7 +78,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in single_file_unet_config.items():
if param_name in PARAMS_TO_IGNORE:
continue
@@ -97,7 +97,7 @@ class StableCascadeUNetModelSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values"]
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in config.items():
if param_name in PARAMS_TO_IGNORE:
continue

View File

@@ -38,7 +38,6 @@ from diffusers.utils.testing_utils import (
get_python_version,
load_image,
load_numpy,
numpy_cosine_similarity_distance,
require_python39_or_higher,
require_torch_2,
require_torch_gpu,
@@ -1063,97 +1062,6 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_sf = StableDiffusionControlNetPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
scheduler_type="pndm",
)
pipe_sf.unet.set_default_attn_processor()
pipe_sf.enable_model_cpu_offload()
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
prompt = "bird"
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(
prompt,
image=control_image,
generator=generator,
output_type="np",
num_inference_steps=3,
).images[0]
generator = torch.Generator(device="cpu").manual_seed(0)
output_sf = pipe_sf(
prompt,
image=control_image,
generator=generator,
output_type="np",
num_inference_steps=3,
).images[0]
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
assert max_diff < 1e-3
def test_single_file_component_configs(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", variant="fp16", safety_checker=None, controlnet=controlnet
)
controlnet_single_file = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
single_file_pipe = StableDiffusionControlNetPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet_single_file,
scheduler_type="pndm",
)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
for param_name, param_value in single_file_pipe.controlnet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
# This parameter doesn't appear to be loaded from the config.
# So when it is registered to config, it remains a tuple as this is the default in the class definition
# from_pretrained, does load from config and converts to a list when registering to config
if param_name == "conditioning_embedding_out_channels" and isinstance(param_value, tuple):
param_value = list(param_value)
assert (
pipe.controlnet.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.vae.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
@slow
@require_torch_gpu

View File

@@ -39,7 +39,6 @@ from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
load_numpy,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
@@ -441,56 +440,3 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
)
assert np.abs(expected_image - image).max() < 9e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_sf = StableDiffusionControlNetImg2ImgPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
safety_checker=None,
controlnet=controlnet,
scheduler_type="pndm",
)
pipe_sf.unet.set_default_attn_processor()
pipe_sf.enable_model_cpu_offload()
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
image = load_image(
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
).resize((512, 512))
prompt = "bird"
generator = torch.Generator(device="cpu").manual_seed(0)
output = pipe(
prompt,
image=image,
control_image=control_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
).images[0]
generator = torch.Generator(device="cpu").manual_seed(0)
output_sf = pipe_sf(
prompt,
image=image,
control_image=control_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
).images[0]
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
assert max_diff < 1e-3

View File

@@ -556,55 +556,3 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
)
assert numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten()) < 1e-2
def test_load_local(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe_1 = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", safety_checker=None, controlnet=controlnet
)
controlnet = ControlNetModel.from_single_file(
"https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
)
pipe_2 = StableDiffusionControlNetInpaintPipeline.from_single_file(
"https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt",
safety_checker=None,
controlnet=controlnet,
)
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
image = load_image(
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
).resize((512, 512))
mask_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
).resize((512, 512))
pipes = [pipe_1, pipe_2]
images = []
for pipe in pipes:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "bird"
output = pipe(
prompt,
image=image,
control_image=control_image,
mask_image=mask_image,
strength=0.9,
generator=generator,
output_type="np",
num_inference_steps=3,
)
images.append(output.images[0])
del pipe
gc.collect()
torch.cuda.empty_cache()
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images[1].flatten())
assert max_diff < 1e-3

View File

@@ -37,7 +37,6 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_image,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
@@ -949,89 +948,6 @@ class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853])
assert np.allclose(original_image, expected_image, atol=1e-04)
def test_download_ckpt_diff_format_is_same(self):
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16)
single_file_url = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
)
pipe_single_file = StableDiffusionXLControlNetPipeline.from_single_file(
single_file_url, controlnet=controlnet, torch_dtype=torch.float16
)
pipe_single_file.unet.set_default_attn_processor()
pipe_single_file.enable_model_cpu_offload()
pipe_single_file.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "Stormtrooper's lecture"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
)
single_file_images = pipe_single_file(
prompt, image=image, generator=generator, output_type="np", num_inference_steps=2
).images
generator = torch.Generator(device="cpu").manual_seed(0)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=2).images
assert images[0].shape == (512, 512, 3)
assert single_file_images[0].shape == (512, 512, 3)
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), single_file_images[0].flatten())
assert max_diff < 5e-2
def test_single_file_component_configs(self):
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
variant="fp16",
controlnet=controlnet,
torch_dtype=torch.float16,
)
single_file_url = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
)
single_file_pipe = StableDiffusionXLControlNetPipeline.from_single_file(
single_file_url, controlnet=controlnet, torch_dtype=torch.float16
)
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
# Upcast attention might be set to None in a config file, which is incorrect. It should default to False in the model
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
pipe.unet.config[param_name] = False
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.vae.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNetPipelineFastTests):
def test_controlnet_sdxl_guess(self):

View File

@@ -42,7 +42,6 @@ from diffusers import (
UNet2DConditionModel,
logging,
)
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.testing_utils import (
CaptureLogger,
enable_full_determinism,
@@ -1284,62 +1283,6 @@ class StableDiffusionPipelineCkptTests(unittest.TestCase):
assert image_out.shape == (512, 512, 3)
def test_download_ckpt_diff_format_is_same(self):
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
sf_pipe = StableDiffusionPipeline.from_single_file(ckpt_path)
sf_pipe.scheduler = DDIMScheduler.from_config(sf_pipe.scheduler.config)
sf_pipe.unet.set_attn_processor(AttnProcessor())
sf_pipe.to("cuda")
generator = torch.Generator(device="cpu").manual_seed(0)
image_single_file = sf_pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_attn_processor(AttnProcessor())
pipe.to("cuda")
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
assert max_diff < 1e-3
def test_single_file_component_configs(self):
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
single_file_pipe = StableDiffusionPipeline.from_single_file(ckpt_path, load_safety_checker=True)
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.vae.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.safety_checker.config.to_dict().items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.safety_checker.config.to_dict()[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
@nightly
@require_torch_gpu

View File

@@ -36,7 +36,6 @@ from diffusers import (
StableDiffusionInpaintPipeline,
UNet2DConditionModel,
)
from diffusers.models.attention_processor import AttnProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
@@ -44,7 +43,6 @@ from diffusers.utils.testing_utils import (
load_image,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
require_python39_or_higher,
require_torch_2,
require_torch_gpu,
@@ -786,77 +784,6 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
expected_slice = np.array([0.3757, 0.3875, 0.4445, 0.4353, 0.3780, 0.4513, 0.3965, 0.3984, 0.4362])
assert np.abs(expected_slice - image_slice).max() < 1e-3
def test_download_local(self):
filename = hf_hub_download("runwayml/stable-diffusion-inpainting", filename="sd-v1-5-inpainting.ckpt")
pipe = StableDiffusionInpaintPipeline.from_single_file(filename, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 1
image_out = pipe(**inputs).images[0]
assert image_out.shape == (512, 512, 3)
def test_download_ckpt_diff_format_is_same(self):
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_attn_processor(AttnProcessor())
pipe.to("cuda")
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 5
image_ckpt = pipe(**inputs).images[0]
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_attn_processor(AttnProcessor())
pipe.to("cuda")
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 5
image = pipe(**inputs).images[0]
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
assert max_diff < 1e-4
def test_single_file_component_configs(self):
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", variant="fp16")
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
single_file_pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path, load_safety_checker=True)
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} is differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.vae.config[param_name] == param_value
), f"{param_name} is differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.safety_checker.config.to_dict().items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.safety_checker.config.to_dict()[param_name] == param_value
), f"{param_name} is differs between single file loading and pretrained loading"
@slow
@require_torch_gpu
@@ -1082,9 +1009,6 @@ class StableDiffusionInpaintPipelineAsymmetricAutoencoderKLSlowTests(unittest.Te
assert image_out.shape == (512, 512, 3)
def test_download_ckpt_diff_format_is_same(self):
pass
@nightly
@require_torch_gpu

View File

@@ -29,7 +29,6 @@ from diffusers.utils.testing_utils import (
floats_tensor,
load_image,
load_numpy,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
@@ -492,73 +491,3 @@ class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 2.9 GB is allocated
assert mem_bytes < 2.9 * 10**9
def test_download_ckpt_diff_format_is_same(self):
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/sd2-upscale/low_res_cat.png"
)
prompt = "a cat sitting on a park bench"
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id)
pipe.enable_model_cpu_offload()
generator = torch.Generator("cpu").manual_seed(0)
output = pipe(prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3)
image_from_pretrained = output.images[0]
single_file_path = (
"https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
)
pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(single_file_path)
pipe_from_single_file.enable_model_cpu_offload()
generator = torch.Generator("cpu").manual_seed(0)
output_from_single_file = pipe_from_single_file(
prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
)
image_from_single_file = output_from_single_file.images[0]
assert image_from_pretrained.shape == (512, 512, 3)
assert image_from_single_file.shape == (512, 512, 3)
assert (
numpy_cosine_similarity_distance(image_from_pretrained.flatten(), image_from_single_file.flatten()) < 1e-3
)
def test_single_file_component_configs(self):
pipe = StableDiffusionUpscalePipeline.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", variant="fp16"
)
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
)
single_file_pipe = StableDiffusionUpscalePipeline.from_single_file(ckpt_path, load_safety_checker=True)
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.vae.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.safety_checker.config.to_dict().items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.safety_checker.config.to_dict()[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"

View File

@@ -30,7 +30,6 @@ from diffusers import (
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_numpy,
@@ -473,30 +472,6 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
assert image_out.shape == (768, 768, 3)
def test_download_ckpt_diff_format_is_same(self):
single_file_path = (
"https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
)
pipe_single = StableDiffusionPipeline.from_single_file(single_file_path)
pipe_single.scheduler = DDIMScheduler.from_config(pipe_single.scheduler.config)
pipe_single.unet.set_attn_processor(AttnProcessor())
pipe_single.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
image_ckpt = pipe_single("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_attn_processor(AttnProcessor())
pipe.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
assert max_diff < 1e-3
def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
number_of_steps = 0

View File

@@ -1046,68 +1046,3 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
assert max_diff < 1e-2
def test_download_ckpt_diff_format_is_same(self):
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
)
pipe = StableDiffusionXLPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
image_ckpt = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe("a turtle", num_inference_steps=2, generator=generator, output_type="np").images[0]
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_ckpt.flatten())
assert max_diff < 6e-3
def test_single_file_component_configs(self):
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
)
single_file_pipe = StableDiffusionXLPipeline.from_single_file(
ckpt_path, variant="fp16", torch_dtype=torch.float16
)
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
pipe.unet.config[param_name] = False
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} is differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.vae.config[param_name] == param_value
), f"{param_name} is differs between single file loading and pretrained loading"

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
@@ -32,13 +31,10 @@ from diffusers import (
T2IAdapter,
UNet2DConditionModel,
)
from diffusers.utils import load_image, logging
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
@@ -678,54 +674,3 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
print(",".join(debug))
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
@require_torch_gpu
class AdapterSDXLPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_download_ckpt_diff_format_is_same(self):
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
)
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
prompt = "toy"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
)
pipe_single_file = StableDiffusionXLAdapterPipeline.from_single_file(
ckpt_path,
adapter=adapter,
torch_dtype=torch.float16,
)
pipe_single_file.enable_model_cpu_offload()
pipe_single_file.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
images_single_file = pipe_single_file(
prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
).images
generator = torch.Generator(device="cpu").manual_seed(0)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
adapter=adapter,
torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
assert images_single_file[0].shape == (768, 512, 3)
assert images[0].shape == (768, 512, 3)
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), images_single_file[0].flatten())
assert max_diff < 5e-3

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
@@ -32,19 +31,15 @@ from transformers import (
from diffusers import (
AutoencoderKL,
AutoencoderTiny,
DDIMScheduler,
EulerDiscreteScheduler,
LCMScheduler,
StableDiffusionXLImg2ImgPipeline,
UNet2DConditionModel,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
@@ -781,85 +776,3 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
def test_save_load_optional_components(self):
self._test_save_load_optional_components()
@slow
class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_download_ckpt_diff_format_is_same(self):
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/sketch-mountains-input.png"
)
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(
prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np"
).images[0]
pipe_single_file = StableDiffusionXLImg2ImgPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
pipe_single_file.scheduler = DDIMScheduler.from_config(pipe_single_file.scheduler.config)
pipe_single_file.unet.set_default_attn_processor()
pipe_single_file.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
image_single_file = pipe_single_file(
prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np"
).images[0]
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
assert max_diff < 5e-2
def test_single_file_component_configs(self):
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
torch_dtype=torch.float16,
variant="fp16",
)
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
single_file_pipe = StableDiffusionXLImg2ImgPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
assert pipe.text_encoder is None
assert single_file_pipe.text_encoder is None
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "architectures", "_use_default_values"]
for param_name, param_value in single_file_pipe.unet.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
pipe.unet.config[param_name] = False
assert (
pipe.unet.config[param_name] == param_value
), f"{param_name} is differs between single file loading and pretrained loading"
for param_name, param_value in single_file_pipe.vae.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
pipe.vae.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"

View File

View File

@@ -0,0 +1,380 @@
import tempfile
from io import BytesIO
import requests
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
torch_device,
)
def download_single_file_checkpoint(repo_id, filename, tmpdir):
path = hf_hub_download(repo_id, filename=filename, local_dir=tmpdir)
return path
def download_original_config(config_url, tmpdir):
original_config_file = BytesIO(requests.get(config_url).content)
path = f"{tmpdir}/config.yaml"
with open(path, "wb") as f:
f.write(original_config_file.read())
return path
def download_diffusers_config(repo_id, tmpdir):
path = snapshot_download(
repo_id,
ignore_patterns=[
"**/*.ckpt",
"*.ckpt",
"**/*.bin",
"*.bin",
"**/*.pt",
"*.pt",
"**/*.safetensors",
"*.safetensors",
],
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
local_dir=tmpdir,
)
return path
class SDSingleFileTesterMixin:
def _compare_component_configs(self, pipe, single_file_pipe):
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
PARAMS_TO_IGNORE = [
"torch_dtype",
"_name_or_path",
"architectures",
"_use_default_values",
"_diffusers_version",
]
for component_name, component in single_file_pipe.components.items():
if component_name in single_file_pipe._optional_components:
continue
# skip testing transformer based components here
# skip text encoders / safety checkers since they have already been tested
if component_name in ["text_encoder", "tokenizer", "safety_checker", "feature_extractor"]:
continue
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
assert isinstance(
component, pipe.components[component_name].__class__
), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
for param_name, param_value in component.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
# Some pretrained configs will set upcast attention to None
# In single file loading it defaults to the value in the class __init__ which is False
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
pipe.components[component_name].config[param_name] = param_value
assert (
pipe.components[component_name].config[param_name] == param_value
), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
def test_single_file_components(self, pipe=None, single_file_pipe=None):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
self.ckpt_path, safety_checker=None
)
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_components_local_files_only(self, pipe=None, single_file_pipe=None):
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
local_ckpt_path, safety_checker=None, local_files_only=True
)
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_components_with_original_config(
self,
pipe=None,
single_file_pipe=None,
):
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
# Not possible to infer this value when original config is provided
# we just pass it in here otherwise this test will fail
upcast_attention = pipe.unet.config.upcast_attention
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
self.ckpt_path,
original_config=self.original_config,
safety_checker=None,
upcast_attention=upcast_attention,
)
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_components_with_original_config_local_files_only(
self,
pipe=None,
single_file_pipe=None,
):
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
# Not possible to infer this value when original config is provided
# we just pass it in here otherwise this test will fail
upcast_attention = pipe.unet.config.upcast_attention
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
local_ckpt_path,
original_config=local_original_config,
safety_checker=None,
upcast_attention=upcast_attention,
local_files_only=True,
)
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None)
sf_pipe.unet.set_attn_processor(AttnProcessor())
sf_pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image_single_file = sf_pipe(**inputs).images[0]
pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
pipe.unet.set_attn_processor(AttnProcessor())
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
assert max_diff < expected_max_diff
def test_single_file_components_with_diffusers_config(
self,
pipe=None,
single_file_pipe=None,
):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
self.ckpt_path, config=self.repo_id, safety_checker=None
)
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_components_with_diffusers_config_local_files_only(
self,
pipe=None,
single_file_pipe=None,
):
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
)
self._compare_component_configs(pipe, single_file_pipe)
class SDXLSingleFileTesterMixin:
def _compare_component_configs(self, pipe, single_file_pipe):
# Skip testing the text_encoder for Refiner Pipelines
if pipe.text_encoder:
for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder.config.to_dict()[param_name] == param_value
for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
continue
assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value
PARAMS_TO_IGNORE = [
"torch_dtype",
"_name_or_path",
"architectures",
"_use_default_values",
"_diffusers_version",
]
for component_name, component in single_file_pipe.components.items():
if component_name in single_file_pipe._optional_components:
continue
# skip text encoders since they have already been tested
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
continue
# skip safety checker if it is not present in the pipeline
if component_name in ["safety_checker", "feature_extractor"]:
continue
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
assert isinstance(
component, pipe.components[component_name].__class__
), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
for param_name, param_value in component.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
# Some pretrained configs will set upcast attention to None
# In single file loading it defaults to the value in the class __init__ which is False
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
pipe.components[component_name].config[param_name] = param_value
assert (
pipe.components[component_name].config[param_name] == param_value
), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
def test_single_file_components(self, pipe=None, single_file_pipe=None):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
self.ckpt_path, safety_checker=None
)
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
self._compare_component_configs(
pipe,
single_file_pipe,
)
def test_single_file_components_local_files_only(
self,
pipe=None,
single_file_pipe=None,
):
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
local_ckpt_path, safety_checker=None, local_files_only=True
)
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_components_with_original_config(
self,
pipe=None,
single_file_pipe=None,
):
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
# Not possible to infer this value when original config is provided
# we just pass it in here otherwise this test will fail
upcast_attention = pipe.unet.config.upcast_attention
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
self.ckpt_path,
original_config=self.original_config,
safety_checker=None,
upcast_attention=upcast_attention,
)
self._compare_component_configs(
pipe,
single_file_pipe,
)
def test_single_file_components_with_original_config_local_files_only(
self,
pipe=None,
single_file_pipe=None,
):
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
# Not possible to infer this value when original config is provided
# we just pass it in here otherwise this test will fail
upcast_attention = pipe.unet.config.upcast_attention
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
local_ckpt_path,
original_config=local_original_config,
upcast_attention=upcast_attention,
safety_checker=None,
local_files_only=True,
)
self._compare_component_configs(
pipe,
single_file_pipe,
)
def test_single_file_components_with_diffusers_config(
self,
pipe=None,
single_file_pipe=None,
):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
self.ckpt_path, config=self.repo_id, safety_checker=None
)
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_components_with_diffusers_config_local_files_only(
self,
pipe=None,
single_file_pipe=None,
):
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
)
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16, safety_checker=None)
sf_pipe.unet.set_default_attn_processor()
sf_pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image_single_file = sf_pipe(**inputs).images[0]
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16, safety_checker=None)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
assert max_diff < expected_max_diff

View File

@@ -0,0 +1,78 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import torch
from diffusers import (
ControlNetModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
slow,
)
enable_full_determinism()
@slow
@require_torch_gpu
class ControlNetModelSingleFileTests(unittest.TestCase):
model_class = ControlNetModel
ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
repo_id = "lllyasviel/control_v11p_sd15_canny"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id)
model_single_file = self.model_class.from_single_file(self.ckpt_path)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path)
assert model_default.config.upcast_attention is False
assert model_default.dtype == torch.float32
torch_dtype = torch.float16
upcast_attention = True
model = self.model_class.from_single_file(
self.ckpt_path,
upcast_attention=upcast_attention,
torch_dtype=torch_dtype,
)
assert model.config.upcast_attention == upcast_attention
assert model.dtype == torch_dtype

View File

@@ -0,0 +1,114 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import torch
from diffusers import StableCascadeUNet
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
slow,
)
logger = logging.get_logger(__name__)
enable_full_determinism()
@slow
@require_torch_gpu
class StableCascadeUNetSingleFileTest(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_single_file_components_stage_b(self):
model_single_file = StableCascadeUNet.from_single_file(
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_bf16.safetensors",
torch_dtype=torch.bfloat16,
)
model = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade", variant="bf16", subfolder="decoder", use_safetensors=True
)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
def test_single_file_components_stage_b_lite(self):
model_single_file = StableCascadeUNet.from_single_file(
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite_bf16.safetensors",
torch_dtype=torch.bfloat16,
)
model = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade", variant="bf16", subfolder="decoder_lite"
)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
def test_single_file_components_stage_c(self):
model_single_file = StableCascadeUNet.from_single_file(
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_bf16.safetensors",
torch_dtype=torch.bfloat16,
)
model = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade-prior", variant="bf16", subfolder="prior"
)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"
def test_single_file_components_stage_c_lite(self):
model_single_file = StableCascadeUNet.from_single_file(
"https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_c_lite_bf16.safetensors",
torch_dtype=torch.bfloat16,
)
model = StableCascadeUNet.from_pretrained(
"stabilityai/stable-cascade-prior", variant="bf16", subfolder="prior_lite"
)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"

View File

@@ -0,0 +1,117 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import torch
from diffusers import (
AutoencoderKL,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_hf_numpy,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
enable_full_determinism()
@slow
@require_torch_gpu
class AutoencoderKLSingleFileTests(unittest.TestCase):
model_class = AutoencoderKL
ckpt_path = (
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
)
repo_id = "stabilityai/sd-vae-ft-mse"
main_input_name = "sample"
base_precision = 1e-2
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return image
def test_single_file_inference_same_as_pretrained(self):
model_1 = self.model_class.from_pretrained(self.repo_id).to(torch_device)
model_2 = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id).to(torch_device)
image = self.get_sd_image(33)
generator = torch.Generator(torch_device)
with torch.no_grad():
sample_1 = model_1(image, generator=generator.manual_seed(0)).sample
sample_2 = model_2(image, generator=generator.manual_seed(0)).sample
assert sample_1.shape == sample_2.shape
output_slice_1 = sample_1.flatten().float().cpu()
output_slice_2 = sample_2.flatten().float().cpu()
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id)
model_single_file = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between pretrained loading and single file loading"
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
assert model_default.config.scaling_factor == 0.18215
assert model_default.config.sample_size == 256
assert model_default.dtype == torch.float32
scaling_factor = 2.0
sample_size = 512
torch_dtype = torch.float16
model = self.model_class.from_single_file(
self.ckpt_path,
config=self.repo_id,
sample_size=sample_size,
scaling_factor=scaling_factor,
torch_dtype=torch_dtype,
)
assert model.config.scaling_factor == scaling_factor
assert model.config.sample_size == sample_size
assert model.dtype == torch_dtype

View File

@@ -0,0 +1,182 @@
import gc
import tempfile
import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_diffusers_config,
download_original_config,
download_single_file_checkpoint,
)
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
original_config = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
repo_id = "runwayml/stable-diffusion-v1-5"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/sketch-mountains-input.png"
)
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
prompt = "bird"
inputs = {
"prompt": prompt,
"image": init_image,
"control_image": control_image,
"generator": generator,
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
pipe_sf = self.pipeline_class.from_single_file(
self.ckpt_path,
controlnet=controlnet,
)
pipe_sf.unet.set_default_attn_processor()
pipe_sf.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
output = pipe(**inputs).images[0]
inputs = self.get_inputs(torch_device)
output_sf = pipe_sf(**inputs).images[0]
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
assert max_diff < 1e-3
def test_single_file_components(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(
self.repo_id, variant="fp16", safety_checker=None, controlnet=controlnet
)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path,
safety_checker=None,
controlnet=controlnet,
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_local_files_only(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_original_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, original_config=self.original_config
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_original_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
controlnet=controlnet,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path,
original_config=local_original_config,
controlnet=controlnet,
safety_checker=None,
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_diffusers_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, original_config=self.original_config
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_diffusers_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
controlnet=controlnet,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path,
config=local_diffusers_config,
safety_checker=None,
controlnet=controlnet,
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)

View File

@@ -0,0 +1,183 @@
import gc
import tempfile
import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
)
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_diffusers_config,
download_original_config,
download_single_file_checkpoint,
)
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetInpaintPipeline
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "runwayml/stable-diffusion-inpainting"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self):
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
image = load_image(
"https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
).resize((512, 512))
mask_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
).resize((512, 512))
inputs = {
"prompt": "bird",
"image": image,
"control_image": control_image,
"mask_image": mask_image,
"generator": torch.Generator(device="cpu").manual_seed(0),
"num_inference_steps": 3,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet, safety_checker=None)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
pipe_sf = self.pipeline_class.from_single_file(self.ckpt_path, controlnet=controlnet, safety_checker=None)
pipe_sf.unet.set_default_attn_processor()
pipe_sf.enable_model_cpu_offload()
inputs = self.get_inputs()
output = pipe(**inputs).images[0]
inputs = self.get_inputs()
output_sf = pipe_sf(**inputs).images[0]
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
assert max_diff < 1e-3
def test_single_file_components(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(
self.repo_id, variant="fp16", safety_checker=None, controlnet=controlnet
)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path,
safety_checker=None,
controlnet=controlnet,
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_local_files_only(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None, controlnet=controlnet)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_original_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, original_config=self.original_config
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_original_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
controlnet=controlnet,
safety_checker=None,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path,
original_config=local_original_config,
controlnet=controlnet,
safety_checker=None,
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_diffusers_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path,
controlnet=controlnet,
config=self.repo_id,
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_diffusers_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny",
torch_dtype=torch.float16,
variant="fp16",
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
controlnet=controlnet,
safety_checker=None,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path,
config=local_diffusers_config,
controlnet=controlnet,
safety_checker=None,
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)

View File

@@ -0,0 +1,171 @@
import gc
import tempfile
import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
)
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_diffusers_config,
download_original_config,
download_single_file_checkpoint,
)
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
original_config = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
repo_id = "runwayml/stable-diffusion-v1-5"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self):
control_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
).resize((512, 512))
inputs = {
"prompt": "bird",
"image": control_image,
"generator": torch.Generator(device="cpu").manual_seed(0),
"num_inference_steps": 3,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
pipe_sf = self.pipeline_class.from_single_file(
self.ckpt_path,
controlnet=controlnet,
)
pipe_sf.unet.set_default_attn_processor()
pipe_sf.enable_model_cpu_offload()
inputs = self.get_inputs()
output = pipe(**inputs).images[0]
inputs = self.get_inputs()
output_sf = pipe_sf(**inputs).images[0]
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
assert max_diff < 1e-3
def test_single_file_components(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(
self.repo_id, variant="fp16", safety_checker=None, controlnet=controlnet
)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path,
safety_checker=None,
controlnet=controlnet,
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_local_files_only(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path, controlnet=controlnet, local_files_only=True
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_original_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, original_config=self.original_config
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_original_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
controlnet=controlnet,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path, original_config=local_original_config, controlnet=controlnet, local_files_only=True
)
pipe_single_file.scheduler = pipe.scheduler
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_diffusers_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, safety_checker=None, config=self.repo_id
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_diffusers_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
controlnet=controlnet,
safety_checker=None,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path,
config=local_diffusers_config,
controlnet=controlnet,
safety_checker=None,
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)

View File

@@ -0,0 +1,99 @@
import gc
import unittest
import torch
from diffusers import (
StableDiffusionImg2ImgPipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
slow,
)
from .single_file_testing_utils import SDSingleFileTesterMixin
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
original_config = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
repo_id = "runwayml/stable-diffusion-v1-5"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/sketch-mountains-input.png"
)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"image": init_image,
"generator": generator,
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
@slow
@require_torch_gpu
class StableDiffusion21Img2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
repo_id = "stabilityai/stable-diffusion-2-1"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/sketch-mountains-input.png"
)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"image": init_image,
"generator": generator,
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)

View File

@@ -0,0 +1,114 @@
import gc
import unittest
import torch
from diffusers import (
StableDiffusionInpaintPipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
slow,
)
from .single_file_testing_utils import SDSingleFileTesterMixin
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "runwayml/stable-diffusion-inpainting"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_image.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
)
inputs = {
"prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
"image": init_image,
"mask_image": mask_image,
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
def test_single_file_loading_4_channel_unet(self):
# Test loading single file inpaint with a 4 channel UNet
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
pipe = self.pipeline_class.from_single_file(ckpt_path)
assert pipe.unet.config.in_channels == 4
@slow
@require_torch_gpu
class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/512-inpainting-ema.safetensors"
)
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inpainting-inference.yaml"
repo_id = "stabilityai/stable-diffusion-2-inpainting"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_image.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_inpaint/input_bench_mask.png"
)
inputs = {
"prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
"image": init_image,
"mask_image": mask_image,
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)

View File

@@ -0,0 +1,114 @@
import gc
import tempfile
import unittest
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
slow,
)
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_original_config,
download_single_file_checkpoint,
)
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
original_config = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
repo_id = "runwayml/stable-diffusion-v1-5"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"generator": generator,
"num_inference_steps": 2,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
def test_single_file_legacy_scheduler_loading(self):
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
pipe = self.pipeline_class.from_single_file(
local_ckpt_path,
original_config=local_original_config,
cache_dir=tmpdir,
local_files_only=True,
scheduler_type="euler",
)
# Default is PNDM for this checkpoint
assert isinstance(pipe.scheduler, EulerDiscreteScheduler)
def test_single_file_legacy_scaling_factor(self):
new_scaling_factor = 10.0
init_pipe = self.pipeline_class.from_single_file(self.ckpt_path)
pipe = self.pipeline_class.from_single_file(self.ckpt_path, scaling_factor=new_scaling_factor)
assert init_pipe.vae.config.scaling_factor != new_scaling_factor
assert pipe.vae.config.scaling_factor == new_scaling_factor
class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
repo_id = "stabilityai/stable-diffusion-2-1"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"generator": generator,
"num_inference_steps": 2,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)

View File

@@ -0,0 +1,68 @@
import gc
import unittest
import torch
from diffusers import (
StableDiffusionUpscalePipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
)
from .single_file_testing_utils import SDSingleFileTesterMixin
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionUpscalePipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionUpscalePipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
repo_id = "stabilityai/stable-diffusion-x4-upscaler"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_single_file_format_inference_is_same_as_pretrained(self):
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/sd2-upscale/low_res_cat.png"
)
prompt = "a cat sitting on a park bench"
pipe = StableDiffusionUpscalePipeline.from_pretrained(self.repo_id)
pipe.enable_model_cpu_offload()
generator = torch.Generator("cpu").manual_seed(0)
output = pipe(prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3)
image_from_pretrained = output.images[0]
pipe_from_single_file = StableDiffusionUpscalePipeline.from_single_file(self.ckpt_path)
pipe_from_single_file.enable_model_cpu_offload()
generator = torch.Generator("cpu").manual_seed(0)
output_from_single_file = pipe_from_single_file(
prompt=prompt, image=image, generator=generator, output_type="np", num_inference_steps=3
)
image_from_single_file = output_from_single_file.images[0]
assert image_from_pretrained.shape == (512, 512, 3)
assert image_from_single_file.shape == (512, 512, 3)
assert (
numpy_cosine_similarity_distance(image_from_pretrained.flatten(), image_from_single_file.flatten()) < 1e-3
)

View File

@@ -0,0 +1,202 @@
import gc
import tempfile
import unittest
import torch
from diffusers import (
StableDiffusionXLAdapterPipeline,
T2IAdapter,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
)
from .single_file_testing_utils import (
SDXLSingleFileTesterMixin,
download_diffusers_config,
download_original_config,
download_single_file_checkpoint,
)
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLAdapterPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
original_config = (
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self):
prompt = "toy"
generator = torch.Generator(device="cpu").manual_seed(0)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
)
inputs = {
"prompt": prompt,
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
pipe_single_file = StableDiffusionXLAdapterPipeline.from_single_file(
self.ckpt_path,
adapter=adapter,
torch_dtype=torch.float16,
safety_checker=None,
)
pipe_single_file.enable_model_cpu_offload()
pipe_single_file.set_progress_bar_config(disable=None)
inputs = self.get_inputs()
images_single_file = pipe_single_file(**inputs).images[0]
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
self.repo_id,
adapter=adapter,
torch_dtype=torch.float16,
safety_checker=None,
)
pipe.enable_model_cpu_offload()
inputs = self.get_inputs()
images = pipe(**inputs).images[0]
assert images_single_file.shape == (768, 512, 3)
assert images.shape == (768, 512, 3)
max_diff = numpy_cosine_similarity_distance(images.flatten(), images_single_file.flatten())
assert max_diff < 5e-3
def test_single_file_components(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
variant="fp16",
adapter=adapter,
torch_dtype=torch.float16,
)
pipe_single_file = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None, adapter=adapter)
super().test_single_file_components(pipe, pipe_single_file)
def test_single_file_components_local_files_only(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
variant="fp16",
adapter=adapter,
torch_dtype=torch.float16,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
single_file_pipe = self.pipeline_class.from_single_file(
local_ckpt_path, adapter=adapter, safety_checker=None, local_files_only=True
)
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_components_with_diffusers_config(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
variant="fp16",
adapter=adapter,
torch_dtype=torch.float16,
safety_checker=None,
)
pipe_single_file = self.pipeline_class.from_single_file(self.ckpt_path, config=self.repo_id, adapter=adapter)
self._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_diffusers_config_local_files_only(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
variant="fp16",
adapter=adapter,
torch_dtype=torch.float16,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path,
config=local_diffusers_config,
adapter=adapter,
safety_checker=None,
local_files_only=True,
)
self._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_original_config(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
variant="fp16",
adapter=adapter,
torch_dtype=torch.float16,
safety_checker=None,
)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path, original_config=self.original_config, adapter=adapter
)
self._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_original_config_local_files_only(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
variant="fp16",
adapter=adapter,
torch_dtype=torch.float16,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path,
original_config=local_original_config,
adapter=adapter,
safety_checker=None,
local_files_only=True,
)
self._compare_component_configs(pipe, pipe_single_file)

View File

@@ -0,0 +1,197 @@
import gc
import tempfile
import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
from .single_file_testing_utils import (
SDXLSingleFileTesterMixin,
download_diffusers_config,
download_single_file_checkpoint,
)
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLControlNetPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
original_config = (
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
)
inputs = {
"prompt": "Stormtrooper's lecture",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, torch_dtype=torch.float16
)
pipe_single_file.unet.set_default_attn_processor()
pipe_single_file.enable_model_cpu_offload()
pipe_single_file.set_progress_bar_config(disable=None)
inputs = self.get_inputs(torch_device)
single_file_images = pipe_single_file(**inputs).images[0]
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet, torch_dtype=torch.float16)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
images = pipe(**inputs).images[0]
assert images.shape == (512, 512, 3)
assert single_file_images.shape == (512, 512, 3)
max_diff = numpy_cosine_similarity_distance(images[0].flatten(), single_file_images[0].flatten())
assert max_diff < 5e-2
def test_single_file_components(self):
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
variant="fp16",
controlnet=controlnet,
torch_dtype=torch.float16,
)
pipe_single_file = self.pipeline_class.from_single_file(self.ckpt_path, controlnet=controlnet)
super().test_single_file_components(pipe, pipe_single_file)
def test_single_file_components_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
variant="fp16",
controlnet=controlnet,
torch_dtype=torch.float16,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
single_file_pipe = self.pipeline_class.from_single_file(
local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
)
self._compare_component_configs(pipe, single_file_pipe)
def test_single_file_components_with_original_config(self):
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
variant="fp16",
controlnet=controlnet,
torch_dtype=torch.float16,
)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path,
original_config=self.original_config,
controlnet=controlnet,
)
self._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_original_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
variant="fp16",
controlnet=controlnet,
torch_dtype=torch.float16,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path,
safety_checker=None,
controlnet=controlnet,
local_files_only=True,
)
self._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_diffusers_config(self):
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
pipe_single_file = self.pipeline_class.from_single_file(
self.ckpt_path, controlnet=controlnet, config=self.repo_id
)
super()._compare_component_configs(pipe, pipe_single_file)
def test_single_file_components_with_diffusers_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
)
pipe = self.pipeline_class.from_pretrained(
self.repo_id,
controlnet=controlnet,
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path,
config=local_diffusers_config,
safety_checker=None,
controlnet=controlnet,
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)

View File

@@ -0,0 +1,105 @@
import gc
import unittest
import torch
from diffusers import (
DDIMScheduler,
StableDiffusionXLImg2ImgPipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
)
from .single_file_testing_utils import SDXLSingleFileTesterMixin
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
original_config = (
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/sketch-mountains-input.png"
)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"image": init_image,
"generator": generator,
"num_inference_steps": 3,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
@slow
@require_torch_gpu
class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests(unittest.TestCase):
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
)
repo_id = "stabilityai/stable-diffusion-xl-refiner-1.0"
original_config = (
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
)
def test_single_file_format_inference_is_same_as_pretrained(self):
init_image = load_image(
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
"/stable_diffusion_img2img/sketch-mountains-input.png"
)
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_default_attn_processor()
pipe.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(
prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np"
).images[0]
pipe_single_file = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16)
pipe_single_file.scheduler = DDIMScheduler.from_config(pipe_single_file.scheduler.config)
pipe_single_file.unet.set_default_attn_processor()
pipe_single_file.enable_model_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
image_single_file = pipe_single_file(
prompt="mountains", image=init_image, num_inference_steps=5, generator=generator, output_type="np"
).images[0]
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
assert max_diff < 5e-4

View File

@@ -0,0 +1,50 @@
import gc
import unittest
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
slow,
)
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionXLInstructPix2PixPipeline(unittest.TestCase):
pipeline_class = StableDiffusionXLInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
original_config = None
repo_id = "diffusers/sdxl-instructpix2pix-768"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"generator": generator,
"num_inference_steps": 2,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_setting_cosxl_edit(self):
# Default is PNDM for this checkpoint
pipe = self.pipeline_class.from_single_file(self.ckpt_path, config=self.repo_id, is_cosxl_edit=True)
assert pipe.is_cosxl_edit is True

View File

@@ -0,0 +1,54 @@
import gc
import unittest
import torch
from diffusers import (
StableDiffusionXLPipeline,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
slow,
)
from .single_file_testing_utils import SDXLSingleFileTesterMixin
enable_full_determinism()
@slow
@require_torch_gpu
class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
original_config = (
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
inputs = {
"prompt": "a fantasy landscape, concept art, high resolution",
"generator": generator,
"num_inference_steps": 2,
"strength": 0.75,
"guidance_scale": 7.5,
"output_type": "np",
}
return inputs
def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)