mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[draft v2] AutoPipeline (#4138)
* initial * style * from ...pipelines -> from ..pipeline_util * make style * fix-copies * fix value_guided_sampling oops * style * add test * Show failing test * update from_pipe * fix * add controlnet, additional test and register unused original config * update for controlnet * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * store unused config as private attribute and pass if can * add doc * kandinsky inpaint pipeline does not work with decoder checkpoint * update doc * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * style * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fix * Apply suggestions from code review --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -182,6 +182,8 @@
|
||||
title: Audio Diffusion
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
|
||||
68
docs/source/en/api/pipelines/auto_pipeline.mdx
Normal file
68
docs/source/en/api/pipelines/auto_pipeline.mdx
Normal file
@@ -0,0 +1,68 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AutoPipeline
|
||||
|
||||
In many cases, one checkpoint can be used for multiple tasks. For example, you may be able to use the same checkpoint for Text-to-Image, Image-to-Image, and Inpainting. However, you'll need to know the pipeline class names linked to your checkpoint.
|
||||
|
||||
AutoPipeline is designed to make it easy for you to use multiple pipelines in your workflow. We currently provide 3 AutoPipeline classes to perform three different tasks, i.e. [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]. You'll need to choose the AutoPipeline class based on the task you want to perform and use it to automatically retrieve the relevant pipeline given the name/path to the pre-trained weights.
|
||||
|
||||
For example, to perform Image-to-Image with the SD1.5 checkpoint, you can do
|
||||
|
||||
```python
|
||||
from diffusers import PipelineForImageToImage
|
||||
|
||||
pipe_i2i = PipelineForImageoImage.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
```
|
||||
|
||||
It will also help you switch between tasks seamlessly using the same checkpoint without reallocating additional memory. For example, to re-use the Image-to-Image pipeline we just created for inpainting, you can do
|
||||
|
||||
```python
|
||||
from diffusers import PipelineForInpainting
|
||||
|
||||
pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_i2i)
|
||||
```
|
||||
All the components will be transferred to the inpainting pipeline with zero cost.
|
||||
|
||||
|
||||
Currently AutoPipeline support the Text-to-Image, Image-to-Image, and Inpainting tasks for below diffusion models:
|
||||
- [stable Diffusion](./stable_diffusion)
|
||||
- [Stable Diffusion Controlnet](./api/pipelines/controlnet)
|
||||
- [Stable Diffusion XL](./stable_diffusion/stable_diffusion_xl)
|
||||
- [IF](./if)
|
||||
- [Kandinsky](./kandinsky)(./kandinsky)(./kandinsky)(./kandinsky)(./kandinsky)
|
||||
- [Kandinsky 2.2]()(./kandinsky)
|
||||
|
||||
|
||||
## AutoPipelineForText2Image
|
||||
|
||||
[[autodoc]] AutoPipelineForText2Image
|
||||
- all
|
||||
- from_pretrained
|
||||
- from_pipe
|
||||
|
||||
|
||||
## AutoPipelineForImage2Image
|
||||
|
||||
[[autodoc]] AutoPipelineForImage2Image
|
||||
- all
|
||||
- from_pretrained
|
||||
- from_pipe
|
||||
|
||||
## AutoPipelineForInpainting
|
||||
|
||||
[[autodoc]] AutoPipelineForInpainting
|
||||
- all
|
||||
- from_pretrained
|
||||
- from_pipe
|
||||
|
||||
|
||||
@@ -62,6 +62,9 @@ else:
|
||||
)
|
||||
from .pipelines import (
|
||||
AudioPipelineOutput,
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForInpainting,
|
||||
AutoPipelineForText2Image,
|
||||
ConsistencyModelPipeline,
|
||||
DanceDiffusionPipeline,
|
||||
DDIMPipeline,
|
||||
|
||||
@@ -17,6 +17,7 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
|
||||
from .consistency_models import ConsistencyModelPipeline
|
||||
from .dance_diffusion import DanceDiffusionPipeline
|
||||
from .ddim import DDIMPipeline
|
||||
|
||||
834
src/diffusers/pipelines/auto_pipeline.py
Normal file
834
src/diffusers/pipelines/auto_pipeline.py
Normal file
@@ -0,0 +1,834 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
||||
from .kandinsky import KandinskyImg2ImgPipeline, KandinskyInpaintPipeline, KandinskyPipeline
|
||||
from .kandinsky2_2 import KandinskyV22Img2ImgPipeline, KandinskyV22InpaintPipeline, KandinskyV22Pipeline
|
||||
from .stable_diffusion import (
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
|
||||
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion", StableDiffusionPipeline),
|
||||
("stable-diffusion-xl", StableDiffusionXLPipeline),
|
||||
("if", IFPipeline),
|
||||
("kandinsky", KandinskyPipeline),
|
||||
("kandinsky22", KandinskyV22Pipeline),
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion", StableDiffusionImg2ImgPipeline),
|
||||
("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
|
||||
("if", IFImg2ImgPipeline),
|
||||
("kandinsky", KandinskyImg2ImgPipeline),
|
||||
("kandinsky22", KandinskyV22Img2ImgPipeline),
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion", StableDiffusionInpaintPipeline),
|
||||
("stable-diffusion-xl", StableDiffusionXLInpaintPipeline),
|
||||
("if", IFInpaintingPipeline),
|
||||
("kandinsky", KandinskyInpaintPipeline),
|
||||
("kandinsky22", KandinskyV22InpaintPipeline),
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
SUPPORTED_TASKS_MAPPINGS = [
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
|
||||
AUTO_INPAINT_PIPELINES_MAPPING,
|
||||
]
|
||||
|
||||
|
||||
def _get_task_class(mapping, pipeline_class_name):
|
||||
def get_model(pipeline_class_name):
|
||||
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
|
||||
for model_name, pipeline in task_mapping.items():
|
||||
if pipeline.__name__ == pipeline_class_name:
|
||||
return model_name
|
||||
|
||||
model_name = get_model(pipeline_class_name)
|
||||
|
||||
if model_name is not None:
|
||||
task_class = mapping.get(model_name, None)
|
||||
if task_class is not None:
|
||||
return task_class
|
||||
raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}")
|
||||
|
||||
|
||||
def _get_signature_keys(obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
||||
expected_modules = set(required_parameters.keys()) - {"self"}
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
|
||||
class AutoPipelineForText2Image(ConfigMixin):
|
||||
r"""
|
||||
|
||||
AutoPipeline for text-to-image generation.
|
||||
|
||||
[`AutoPipelineForText2Image`] is a generic pipeline class that will be instantiated as one of the text-to-image
|
||||
pipeline class in diffusers.
|
||||
|
||||
The pipeline type (for example [`StableDiffusionPipeline`]) is automatically selected when created with the
|
||||
AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path) or
|
||||
AutoPipelineForText2Image.from_pipe(pipeline) class methods .
|
||||
|
||||
This class cannot be instantiated using __init__() (throws an error).
|
||||
|
||||
Class attributes:
|
||||
|
||||
- **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
|
||||
diffusion pipeline's components.
|
||||
|
||||
"""
|
||||
config_name = "model_index.json"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{self.__class__.__name__} is designed to be instantiated "
|
||||
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiates a text-to-image Pytorch diffusion pipeline from pretrained pipeline weight.
|
||||
|
||||
The from_pretrained() method takes care of returning the correct pipeline class instance by:
|
||||
1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its
|
||||
config object
|
||||
2. Find the text-to-image pipeline linked to the pipeline class using pattern matching on pipeline class
|
||||
name.
|
||||
|
||||
If a `controlnet` argument is passed, it will instantiate a [`StableDiffusionControlNetPipeline`] object.
|
||||
|
||||
The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
If you get the error message below, you need to finetune the weights for your downstream task:
|
||||
|
||||
```
|
||||
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
||||
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
```
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
|
||||
saved using
|
||||
[`~DiffusionPipeline.save_pretrained`].
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
custom_revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
|
||||
`revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
|
||||
custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn’t need to be defined for each
|
||||
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
||||
each GPU and the available CPU RAM if unset.
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
The path to offload weights if device_map contains the value `"disk"`.
|
||||
offload_state_dict (`bool`, *optional*):
|
||||
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
||||
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
||||
when there is some disk offload.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
below for more information.
|
||||
variant (`str`, *optional*):
|
||||
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
|
||||
loading `from_flax`.
|
||||
|
||||
<Tip>
|
||||
|
||||
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
|
||||
`huggingface-cli login`.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForTextToImage
|
||||
|
||||
>>> pipeline = AutoPipelineForTextToImage.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
>>> print(pipeline.__class__)
|
||||
```
|
||||
"""
|
||||
config = cls.load_config(pretrained_model_or_path)
|
||||
orig_class_name = config["_class_name"]
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
|
||||
|
||||
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pipe(cls, pipeline, **kwargs):
|
||||
r"""
|
||||
Instantiates a text-to-image Pytorch diffusion pipeline from another instantiated diffusion pipeline class.
|
||||
|
||||
The from_pipe() method takes care of returning the correct pipeline class instance by finding the text-to-image
|
||||
pipeline linked to the pipeline class using pattern matching on pipeline class name.
|
||||
|
||||
All the modules the pipeline contains will be used to initialize the new pipeline without reallocating
|
||||
additional memoery.
|
||||
|
||||
The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pipeline (`DiffusionPipeline`):
|
||||
an instantiated `DiffusionPipeline` object
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForTextToImage, AutoPipelineForImageToImage
|
||||
|
||||
>>> pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
|
||||
... )
|
||||
|
||||
>>> pipe_t2i = AutoPipelineForTextToImage.from_pipe(pipe_t2i)
|
||||
```
|
||||
"""
|
||||
|
||||
original_config = dict(pipeline.config)
|
||||
original_cls_name = pipeline.__class__.__name__
|
||||
|
||||
# derive the pipeline class to instantiate
|
||||
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, original_cls_name)
|
||||
|
||||
# define expected module and optional kwargs given the pipeline signature
|
||||
expected_modules, optional_kwargs = _get_signature_keys(text_2_image_cls)
|
||||
|
||||
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
|
||||
|
||||
# allow users pass modules in `kwargs` to override the original pipeline's components
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
original_class_obj = {
|
||||
k: pipeline.components[k]
|
||||
for k, v in pipeline.components.items()
|
||||
if k in expected_modules and k not in passed_class_obj
|
||||
}
|
||||
|
||||
# allow users pass optional kwargs to override the original pipelines config attribute
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
original_pipe_kwargs = {
|
||||
k: original_config[k]
|
||||
for k, v in original_config.items()
|
||||
if k in optional_kwargs and k not in passed_pipe_kwargs
|
||||
}
|
||||
|
||||
# config that were not expected by original pipeline is stored as private attribute
|
||||
# we will pass them as optional arguments if they can be accepted by the pipeline
|
||||
additional_pipe_kwargs = [
|
||||
k[1:]
|
||||
for k in original_config.keys()
|
||||
if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
|
||||
]
|
||||
for k in additional_pipe_kwargs:
|
||||
original_pipe_kwargs[k] = original_config.pop(f"_{k}")
|
||||
|
||||
text_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}
|
||||
|
||||
# store unused config as private attribute
|
||||
unused_original_config = {
|
||||
f"{'' if k.startswith('_') else '_'}{k}": original_config[k]
|
||||
for k, v in original_config.items()
|
||||
if k not in text_2_image_kwargs
|
||||
}
|
||||
|
||||
missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(text_2_image_kwargs.keys())
|
||||
|
||||
if len(missing_modules) > 0:
|
||||
raise ValueError(
|
||||
f"Pipeline {text_2_image_cls} expected {expected_modules}, but only {set(passed_class_obj.keys()) + set(original_class_obj.keys())} were passed"
|
||||
)
|
||||
|
||||
model = text_2_image_cls(**text_2_image_kwargs)
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
model.register_to_config(**unused_original_config)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class AutoPipelineForImage2Image(ConfigMixin):
|
||||
r"""
|
||||
|
||||
AutoPipeline for image-to-image generation.
|
||||
|
||||
[`AutoPipelineForImage2Image`] is a generic pipeline class that will be instantiated as one of the image-to-image
|
||||
pipeline classes in diffusers.
|
||||
|
||||
The pipeline type (for example [`StableDiffusionImg2ImgPipeline`]) is automatically selected when created with the
|
||||
`AutoPipelineForImage2Image.from_pretrained(pretrained_model_name_or_path)` or
|
||||
`AutoPipelineForImage2Image.from_pipe(pipeline)` class methods.
|
||||
|
||||
This class cannot be instantiated using __init__() (throws an error).
|
||||
|
||||
Class attributes:
|
||||
|
||||
- **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
|
||||
diffusion pipeline's components.
|
||||
|
||||
"""
|
||||
config_name = "model_index.json"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{self.__class__.__name__} is designed to be instantiated "
|
||||
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiates a image-to-image Pytorch diffusion pipeline from pretrained pipeline weight.
|
||||
|
||||
The from_pretrained() method takes care of returning the correct pipeline class instance by:
|
||||
1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its
|
||||
config object
|
||||
2. Find the image-to-image pipeline linked to the pipeline class using pattern matching on pipeline class
|
||||
name.
|
||||
|
||||
If a `controlnet` argument is passed, it will instantiate a StableDiffusionControlNetImg2ImgPipeline object.
|
||||
|
||||
The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
If you get the error message below, you need to finetune the weights for your downstream task:
|
||||
|
||||
```
|
||||
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
||||
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
```
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
|
||||
saved using
|
||||
[`~DiffusionPipeline.save_pretrained`].
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
custom_revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
|
||||
`revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
|
||||
custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn’t need to be defined for each
|
||||
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
||||
each GPU and the available CPU RAM if unset.
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
The path to offload weights if device_map contains the value `"disk"`.
|
||||
offload_state_dict (`bool`, *optional*):
|
||||
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
||||
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
||||
when there is some disk offload.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
below for more information.
|
||||
variant (`str`, *optional*):
|
||||
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
|
||||
loading `from_flax`.
|
||||
|
||||
<Tip>
|
||||
|
||||
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
|
||||
`huggingface-cli login`.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForTextToImage
|
||||
|
||||
>>> pipeline = AutoPipelineForImageToImage.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
>>> print(pipeline.__class__)
|
||||
```
|
||||
"""
|
||||
config = cls.load_config(pretrained_model_or_path)
|
||||
orig_class_name = config["_class_name"]
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
|
||||
|
||||
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pipe(cls, pipeline, **kwargs):
|
||||
r"""
|
||||
Instantiates a image-to-image Pytorch diffusion pipeline from another instantiated diffusion pipeline class.
|
||||
|
||||
The from_pipe() method takes care of returning the correct pipeline class instance by finding the
|
||||
image-to-image pipeline linked to the pipeline class using pattern matching on pipeline class name.
|
||||
|
||||
All the modules the pipeline contains will be used to initialize the new pipeline without reallocating
|
||||
additional memoery.
|
||||
|
||||
The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pipeline (`DiffusionPipeline`):
|
||||
an instantiated `DiffusionPipeline` object
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForTextToImage, AutoPipelineForImageToImage
|
||||
|
||||
>>> pipe_t2i = AutoPipelineForText2Image.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
|
||||
... )
|
||||
|
||||
>>> pipe_i2i = AutoPipelineForImageToImage.from_pipe(pipe_t2i)
|
||||
```
|
||||
"""
|
||||
|
||||
original_config = dict(pipeline.config)
|
||||
original_cls_name = pipeline.__class__.__name__
|
||||
|
||||
# derive the pipeline class to instantiate
|
||||
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, original_cls_name)
|
||||
|
||||
# define expected module and optional kwargs given the pipeline signature
|
||||
expected_modules, optional_kwargs = _get_signature_keys(image_2_image_cls)
|
||||
|
||||
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
|
||||
|
||||
# allow users pass modules in `kwargs` to override the original pipeline's components
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
original_class_obj = {
|
||||
k: pipeline.components[k]
|
||||
for k, v in pipeline.components.items()
|
||||
if k in expected_modules and k not in passed_class_obj
|
||||
}
|
||||
|
||||
# allow users pass optional kwargs to override the original pipelines config attribute
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
original_pipe_kwargs = {
|
||||
k: original_config[k]
|
||||
for k, v in original_config.items()
|
||||
if k in optional_kwargs and k not in passed_pipe_kwargs
|
||||
}
|
||||
|
||||
# config attribute that were not expected by original pipeline is stored as its private attribute
|
||||
# we will pass them as optional arguments if they can be accepted by the pipeline
|
||||
additional_pipe_kwargs = [
|
||||
k[1:]
|
||||
for k in original_config.keys()
|
||||
if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
|
||||
]
|
||||
for k in additional_pipe_kwargs:
|
||||
original_pipe_kwargs[k] = original_config.pop(f"_{k}")
|
||||
|
||||
image_2_image_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}
|
||||
|
||||
# store unused config as private attribute
|
||||
unused_original_config = {
|
||||
f"{'' if k.startswith('_') else '_'}{k}": original_config[k]
|
||||
for k, v in original_config.items()
|
||||
if k not in image_2_image_kwargs
|
||||
}
|
||||
|
||||
missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(image_2_image_kwargs.keys())
|
||||
|
||||
if len(missing_modules) > 0:
|
||||
raise ValueError(
|
||||
f"Pipeline {image_2_image_cls} expected {expected_modules}, but only {set(passed_class_obj.keys()) + set(original_class_obj.keys())} were passed"
|
||||
)
|
||||
|
||||
model = image_2_image_cls(**image_2_image_kwargs)
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
model.register_to_config(**unused_original_config)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class AutoPipelineForInpainting(ConfigMixin):
|
||||
r"""
|
||||
|
||||
AutoPipeline for inpainting generation.
|
||||
|
||||
[`AutoPipelineForInpainting`] is a generic pipeline class that will be instantiated as one of the inpainting
|
||||
pipeline class in diffusers.
|
||||
|
||||
The pipeline type (for example [`IFInpaintingPipeline`]) is automatically selected when created with the
|
||||
AutoPipelineForInpainting.from_pretrained(pretrained_model_name_or_path) or
|
||||
AutoPipelineForInpainting.from_pipe(pipeline) class methods .
|
||||
|
||||
This class cannot be instantiated using __init__() (throws an error).
|
||||
|
||||
Class attributes:
|
||||
|
||||
- **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
|
||||
diffusion pipeline's components.
|
||||
|
||||
"""
|
||||
config_name = "model_index.json"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{self.__class__.__name__} is designed to be instantiated "
|
||||
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiates a inpainting Pytorch diffusion pipeline from pretrained pipeline weight.
|
||||
|
||||
The from_pretrained() method takes care of returning the correct pipeline class instance by:
|
||||
1. Detect the pipeline class of the pretrained_model_or_path based on the _class_name property of its
|
||||
config object
|
||||
2. Find the inpainting pipeline linked to the pipeline class using pattern matching on pipeline class name.
|
||||
|
||||
If a `controlnet` argument is passed, it will instantiate a StableDiffusionControlNetInpaintPipeline object.
|
||||
|
||||
The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
If you get the error message below, you need to finetune the weights for your downstream task:
|
||||
|
||||
```
|
||||
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
|
||||
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
|
||||
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
|
||||
```
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
|
||||
saved using
|
||||
[`~DiffusionPipeline.save_pretrained`].
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
custom_revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
|
||||
`revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
|
||||
custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn’t need to be defined for each
|
||||
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
|
||||
each GPU and the available CPU RAM if unset.
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
The path to offload weights if device_map contains the value `"disk"`.
|
||||
offload_state_dict (`bool`, *optional*):
|
||||
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
|
||||
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
|
||||
when there is some disk offload.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
below for more information.
|
||||
variant (`str`, *optional*):
|
||||
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
|
||||
loading `from_flax`.
|
||||
|
||||
<Tip>
|
||||
|
||||
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
|
||||
`huggingface-cli login`.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForTextToImage
|
||||
|
||||
>>> pipeline = AutoPipelineForImageToImage.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
>>> print(pipeline.__class__)
|
||||
```
|
||||
"""
|
||||
config = cls.load_config(pretrained_model_or_path)
|
||||
orig_class_name = config["_class_name"]
|
||||
|
||||
if "controlnet" in kwargs:
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
|
||||
|
||||
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_pipe(cls, pipeline, **kwargs):
|
||||
r"""
|
||||
Instantiates a inpainting Pytorch diffusion pipeline from another instantiated diffusion pipeline class.
|
||||
|
||||
The from_pipe() method takes care of returning the correct pipeline class instance by finding the inpainting
|
||||
pipeline linked to the pipeline class using pattern matching on pipeline class name.
|
||||
|
||||
All the modules the pipeline class contain will be used to initialize the new pipeline without reallocating
|
||||
additional memoery.
|
||||
|
||||
The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pipeline (`DiffusionPipeline`):
|
||||
an instantiated `DiffusionPipeline` object
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import AutoPipelineForTextToImage, AutoPipelineForInpainting
|
||||
|
||||
>>> pipe_t2i = AutoPipelineForText2Image.from_pretrained(
|
||||
... "DeepFloyd/IF-I-XL-v1.0", requires_safety_checker=False
|
||||
... )
|
||||
|
||||
>>> pipe_inpaint = AutoPipelineForInpainting.from_pipe(pipe_t2i)
|
||||
```
|
||||
"""
|
||||
original_config = dict(pipeline.config)
|
||||
original_cls_name = pipeline.__class__.__name__
|
||||
|
||||
# derive the pipeline class to instantiate
|
||||
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, original_cls_name)
|
||||
|
||||
# define expected module and optional kwargs given the pipeline signature
|
||||
expected_modules, optional_kwargs = _get_signature_keys(inpainting_cls)
|
||||
|
||||
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
|
||||
|
||||
# allow users pass modules in `kwargs` to override the original pipeline's components
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
original_class_obj = {
|
||||
k: pipeline.components[k]
|
||||
for k, v in pipeline.components.items()
|
||||
if k in expected_modules and k not in passed_class_obj
|
||||
}
|
||||
|
||||
# allow users pass optional kwargs to override the original pipelines config attribute
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
original_pipe_kwargs = {
|
||||
k: original_config[k]
|
||||
for k, v in original_config.items()
|
||||
if k in optional_kwargs and k not in passed_pipe_kwargs
|
||||
}
|
||||
|
||||
# config that were not expected by original pipeline is stored as private attribute
|
||||
# we will pass them as optional arguments if they can be accepted by the pipeline
|
||||
additional_pipe_kwargs = [
|
||||
k[1:]
|
||||
for k in original_config.keys()
|
||||
if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
|
||||
]
|
||||
for k in additional_pipe_kwargs:
|
||||
original_pipe_kwargs[k] = original_config.pop(f"_{k}")
|
||||
|
||||
inpainting_kwargs = {**passed_class_obj, **original_class_obj, **passed_pipe_kwargs, **original_pipe_kwargs}
|
||||
|
||||
# store unused config as private attribute
|
||||
unused_original_config = {
|
||||
f"{'' if k.startswith('_') else '_'}{k}": original_config[k]
|
||||
for k, v in original_config.items()
|
||||
if k not in inpainting_kwargs
|
||||
}
|
||||
|
||||
missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(inpainting_kwargs.keys())
|
||||
|
||||
if len(missing_modules) > 0:
|
||||
raise ValueError(
|
||||
f"Pipeline {inpainting_cls} expected {expected_modules}, but only {set(passed_class_obj.keys()) + set(original_class_obj.keys())} were passed"
|
||||
)
|
||||
|
||||
model = inpainting_cls(**inpainting_kwargs)
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
model.register_to_config(**unused_original_config)
|
||||
|
||||
return model
|
||||
@@ -20,8 +20,6 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDIMScheduler, DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
@@ -30,6 +28,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from .text_encoder import MultilingualCLIP
|
||||
|
||||
|
||||
|
||||
@@ -23,8 +23,6 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDIMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
@@ -33,6 +31,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from .text_encoder import MultilingualCLIP
|
||||
|
||||
|
||||
|
||||
@@ -25,8 +25,6 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDIMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
@@ -35,6 +33,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from .text_encoder import MultilingualCLIP
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
@@ -29,6 +28,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -17,8 +17,6 @@ from typing import List, Optional, Union
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
@@ -27,6 +25,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -17,8 +17,6 @@ from typing import List, Optional, Union
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
@@ -27,6 +25,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -20,8 +20,6 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
@@ -30,6 +28,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -20,8 +20,6 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
@@ -30,6 +28,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -22,8 +22,6 @@ import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
|
||||
from ...models import UNet2DConditionModel, VQModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
@@ -32,6 +30,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -5,7 +5,6 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import (
|
||||
logging,
|
||||
@@ -13,6 +12,7 @@ from ...utils import (
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..kandinsky import KandinskyPriorPipelineOutput
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -5,7 +5,6 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import (
|
||||
logging,
|
||||
@@ -13,6 +12,7 @@ from ...utils import (
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..kandinsky import KandinskyPriorPipelineOutput
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -22,7 +22,6 @@ import torch
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import HeunDiscreteScheduler
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
@@ -32,6 +31,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .renderer import ShapERenderer
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||
|
||||
from ...models import PriorTransformer
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import HeunDiscreteScheduler
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
@@ -29,6 +28,7 @@ from ...utils import (
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .renderer import ShapERenderer
|
||||
|
||||
|
||||
|
||||
@@ -23,9 +23,9 @@ from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...schedulers import LMSDiscreteScheduler
|
||||
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
|
||||
@@ -21,10 +21,9 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
|
||||
|
||||
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...pipelines.pipeline_utils import ImagePipelineOutput
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import logging, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
|
||||
|
||||
|
||||
@@ -26,9 +26,9 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...models import UNet2DConditionModel, UNet2DModel
|
||||
from ...pipelines import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import logging, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
|
||||
|
||||
|
||||
@@ -255,6 +255,51 @@ class AudioPipelineOutput(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoPipelineForImage2Image(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoPipelineForInpainting(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoPipelineForText2Image(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ConsistencyModelPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
201
tests/pipelines/test_pipelines_auto.py
Normal file
201
tests/pipelines/test_pipelines_auto.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForInpainting,
|
||||
AutoPipelineForText2Image,
|
||||
ControlNetModel,
|
||||
)
|
||||
from diffusers.pipelines.auto_pipeline import (
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
|
||||
AUTO_INPAINT_PIPELINES_MAPPING,
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
|
||||
)
|
||||
from diffusers.utils import slow
|
||||
|
||||
|
||||
PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion", "runwayml/stable-diffusion-v1-5"),
|
||||
("if", "DeepFloyd/IF-I-XL-v1.0"),
|
||||
("kandinsky", "kandinsky-community/kandinsky-2-1"),
|
||||
("kandinsky22", "kandinsky-community/kandinsky-2-2-decoder"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class AutoPipelineFastTest(unittest.TestCase):
|
||||
def test_from_pipe_consistent(self):
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
|
||||
)
|
||||
original_config = dict(pipe.config)
|
||||
|
||||
pipe = AutoPipelineForImage2Image.from_pipe(pipe)
|
||||
assert dict(pipe.config) == original_config
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pipe(pipe)
|
||||
assert dict(pipe.config) == original_config
|
||||
|
||||
def test_from_pipe_override(self):
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
|
||||
)
|
||||
|
||||
pipe = AutoPipelineForImage2Image.from_pipe(pipe, requires_safety_checker=True)
|
||||
assert pipe.config.requires_safety_checker is True
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pipe(pipe, requires_safety_checker=True)
|
||||
assert pipe.config.requires_safety_checker is True
|
||||
|
||||
def test_from_pipe_consistent_sdxl(self):
|
||||
pipe = AutoPipelineForImage2Image.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-xl-pipe",
|
||||
requires_aesthetics_score=True,
|
||||
force_zeros_for_empty_prompt=False,
|
||||
)
|
||||
|
||||
original_config = dict(pipe.config)
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pipe(pipe)
|
||||
pipe = AutoPipelineForImage2Image.from_pipe(pipe)
|
||||
|
||||
assert dict(pipe.config) == original_config
|
||||
|
||||
|
||||
@slow
|
||||
class AutoPipelineIntegrationTest(unittest.TestCase):
|
||||
def test_pipe_auto(self):
|
||||
for model_name, model_repo in PRETRAINED_MODEL_REPO_MAPPING.items():
|
||||
# test txt2img
|
||||
pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
|
||||
model_repo, variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
self.assertIsInstance(pipe_txt2img, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name])
|
||||
|
||||
pipe_to = AutoPipelineForText2Image.from_pipe(pipe_txt2img)
|
||||
self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name])
|
||||
|
||||
pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_txt2img)
|
||||
self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name])
|
||||
|
||||
if "kandinsky" not in model_name:
|
||||
pipe_to = AutoPipelineForInpainting.from_pipe(pipe_txt2img)
|
||||
self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING[model_name])
|
||||
|
||||
del pipe_txt2img, pipe_to
|
||||
gc.collect()
|
||||
|
||||
# test img2img
|
||||
|
||||
pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
|
||||
model_repo, variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
self.assertIsInstance(pipe_img2img, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name])
|
||||
|
||||
pipe_to = AutoPipelineForText2Image.from_pipe(pipe_img2img)
|
||||
self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name])
|
||||
|
||||
pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_img2img)
|
||||
self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name])
|
||||
|
||||
if "kandinsky" not in model_name:
|
||||
pipe_to = AutoPipelineForInpainting.from_pipe(pipe_img2img)
|
||||
self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING[model_name])
|
||||
|
||||
del pipe_img2img, pipe_to
|
||||
gc.collect()
|
||||
|
||||
# test inpaint
|
||||
|
||||
if "kandinsky" not in model_name:
|
||||
pipe_inpaint = AutoPipelineForInpainting.from_pretrained(
|
||||
model_repo, variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
self.assertIsInstance(pipe_inpaint, AUTO_INPAINT_PIPELINES_MAPPING[model_name])
|
||||
|
||||
pipe_to = AutoPipelineForText2Image.from_pipe(pipe_inpaint)
|
||||
self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING[model_name])
|
||||
|
||||
pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_inpaint)
|
||||
self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING[model_name])
|
||||
|
||||
pipe_to = AutoPipelineForInpainting.from_pipe(pipe_inpaint)
|
||||
self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING[model_name])
|
||||
|
||||
del pipe_inpaint, pipe_to
|
||||
gc.collect()
|
||||
|
||||
def test_from_pipe_consistent(self):
|
||||
for model_name, model_repo in PRETRAINED_MODEL_REPO_MAPPING.items():
|
||||
if model_name in ["kandinsky", "kandinsky22"]:
|
||||
auto_pipes = [AutoPipelineForText2Image, AutoPipelineForImage2Image]
|
||||
else:
|
||||
auto_pipes = [AutoPipelineForText2Image, AutoPipelineForImage2Image, AutoPipelineForInpainting]
|
||||
|
||||
# test from_pretrained
|
||||
for pipe_from_class in auto_pipes:
|
||||
pipe_from = pipe_from_class.from_pretrained(model_repo, variant="fp16", torch_dtype=torch.float16)
|
||||
pipe_from_config = dict(pipe_from.config)
|
||||
|
||||
for pipe_to_class in auto_pipes:
|
||||
pipe_to = pipe_to_class.from_pipe(pipe_from)
|
||||
self.assertEqual(dict(pipe_to.config), pipe_from_config)
|
||||
|
||||
del pipe_from, pipe_to
|
||||
gc.collect()
|
||||
|
||||
def test_controlnet(self):
|
||||
# test from_pretrained
|
||||
model_repo = "runwayml/stable-diffusion-v1-5"
|
||||
controlnet_repo = "lllyasviel/sd-controlnet-canny"
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(controlnet_repo, torch_dtype=torch.float16)
|
||||
|
||||
pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
|
||||
model_repo, controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
self.assertIsInstance(pipe_txt2img, AUTO_TEXT2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"])
|
||||
|
||||
pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
|
||||
model_repo, controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
self.assertIsInstance(pipe_img2img, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"])
|
||||
|
||||
pipe_inpaint = AutoPipelineForInpainting.from_pretrained(
|
||||
model_repo, controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
self.assertIsInstance(pipe_inpaint, AUTO_INPAINT_PIPELINES_MAPPING["stable-diffusion-controlnet"])
|
||||
|
||||
# test from_pipe
|
||||
for pipe_from in [pipe_txt2img, pipe_img2img, pipe_inpaint]:
|
||||
pipe_to = AutoPipelineForText2Image.from_pipe(pipe_from)
|
||||
self.assertIsInstance(pipe_to, AUTO_TEXT2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"])
|
||||
self.assertEqual(dict(pipe_to.config), dict(pipe_txt2img.config))
|
||||
|
||||
pipe_to = AutoPipelineForImage2Image.from_pipe(pipe_from)
|
||||
self.assertIsInstance(pipe_to, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["stable-diffusion-controlnet"])
|
||||
self.assertEqual(dict(pipe_to.config), dict(pipe_img2img.config))
|
||||
|
||||
pipe_to = AutoPipelineForInpainting.from_pipe(pipe_from)
|
||||
self.assertIsInstance(pipe_to, AUTO_INPAINT_PIPELINES_MAPPING["stable-diffusion-controlnet"])
|
||||
self.assertEqual(dict(pipe_to.config), dict(pipe_inpaint.config))
|
||||
Reference in New Issue
Block a user