mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add Scheduler.from_pretrained and better scheduler changing (#1286)

* add conversion script for vae

* uP

* uP

* more changes

* push

* up

* finish again

* up

* up

* up

* up

* finish

* up

* uP

* up

* Apply suggestions from code review

* up

* up

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Patrick von Platen
2022-11-15 18:15:13 +01:00
committed by GitHub
parent db1cb0b1a2
commit a0520193e1
55 changed files with 1149 additions and 407 deletions

View File

@@ -152,15 +152,7 @@ it before the pipeline and pass it to `from_pretrained`.
```python
from diffusers import LMSDiscreteScheduler
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
scheduler=lms,
)
pipe = pipe.to("cuda")
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

View File

@@ -10,6 +10,8 @@
- sections:
- local: using-diffusers/loading
title: "Loading Pipelines, Models, and Schedulers"
- local: using-diffusers/schedulers
title: "Using different Schedulers"
- local: using-diffusers/configuration
title: "Configuring Pipelines, Models, and Schedulers"
- local: using-diffusers/custom_pipeline_overview

View File

@@ -15,9 +15,9 @@ specific language governing permissions and limitations under the License.
In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are
passed to the respective `__init__` methods in a JSON-configuration file.
TODO(PVP) - add example and better info here
## ConfigMixin
[[autodoc]] ConfigMixin
- load_config
- from_config
- save_config
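As a rough illustration of the idea described above, the following sketch records the arguments passed to `__init__` and serializes them to JSON. It is not the library's implementation: `ToyConfigMixin`, `ToyScheduler`, `record_init_args`, and `to_json` are hypothetical names chosen for this example only.

```python
import json


class ToyConfigMixin:
    """Hypothetical stand-in: remembers the arguments its subclass was built with."""

    def record_init_args(self, **init_args):
        # Keep the arguments in a private dict, loosely mirroring `self.config`.
        self._recorded = dict(init_args)

    @property
    def config(self):
        # Return a copy so callers cannot change the recorded values by accident.
        return dict(self._recorded)

    def to_json(self):
        # Store the class name next to the parameters so the file is self-describing.
        payload = {"_class_name": type(self).__name__, **self._recorded}
        return json.dumps(payload, indent=2, sort_keys=True)


class ToyScheduler(ToyConfigMixin):
    def __init__(self, num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.record_init_args(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
        )


scheduler = ToyScheduler(beta_end=0.012)
config_json = scheduler.to_json()
print(config_json)
```

The JSON string contains every `__init__` argument, including defaults that were not passed explicitly, which is what makes it possible to rebuild an equivalent object later.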

View File

@@ -39,7 +39,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image

View File

@@ -54,7 +54,7 @@ original_image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))
# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256")
scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
pipe = pipe.to("cuda")

View File

@@ -34,13 +34,17 @@ For more details about how Stable Diffusion works and how it differs from the ba
### How to load and use different schedulers.
The stable diffusion pipeline uses the [`PNDMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline, such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], and [`EulerAncestralDiscreteScheduler`].
To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
```

View File

@@ -41,7 +41,7 @@ In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generat
```python
>>> from diffusers import DiffusionPipeline
>>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
```
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
@@ -49,13 +49,13 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm
You can move the pipeline object to the GPU, just like you would in PyTorch.
```python
>>> generator.to("cuda")
>>> pipeline.to("cuda")
```
Now you can use the `generator` on your text prompt:
Now you can use the `pipeline` on your text prompt:
```python
>>> image = generator("An image of a squirrel in Picasso style").images[0]
>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
```
The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).
@@ -82,7 +82,7 @@ just like we did before only that now you need to pass your `AUTH_TOKEN`:
```python
>>> from diffusers import DiffusionPipeline
>>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
```
If you do not pass your authentication token you will see that the diffusion system will not be correctly
@@ -102,7 +102,7 @@ token. Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned
you can also load the pipeline as follows:
```python
>>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
```
Running the pipeline is then identical to the code above as it's the same model architecture.
@@ -115,19 +115,20 @@ Running the pipeline is then identical to the code above as it's the same model
Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their
pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler,
use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`] scheduler,
you could use it as follows:
```python
>>> from diffusers import LMSDiscreteScheduler
>>> from diffusers import EulerDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
>>> generator = StableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
... )
>>> # change scheduler to Euler
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
```
For more detailed information on how to switch between schedulers, please refer to the [Using Schedulers](./using-diffusers/schedulers) guide.
[Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model
and can do much more than just generating images from text. We have dedicated a whole documentation page,
just for Stable Diffusion [here](./conceptual/stable_diffusion).

View File

@@ -19,7 +19,7 @@ In the following we explain in-detail how to easily load:
- *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`]
- *Diffusion Models* via [`ModelMixin.from_pretrained`]
- *Schedulers* via [`ConfigMixin.from_config`]
- *Schedulers* via [`SchedulerMixin.from_pretrained`]
## Loading pipelines
@@ -137,15 +137,15 @@ from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultis
repo_id = "runwayml/stable-diffusion-v1-5"
scheduler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
# or
# scheduler = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
# scheduler = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)
```
Three things are worth paying attention to here.
- First, the scheduler is loaded with [`ConfigMixin.from_config`] since it only depends on a configuration file and not any parameterized weights
- First, the scheduler is loaded with [`SchedulerMixin.from_pretrained`]
- Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"` as the configuration of stable diffusion's scheduling is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler)
- Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler` since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`]
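The third point can be sketched with a toy class: a component keyword is only meaningful if it matches a parameter name of the pipeline's `__init__`. `ToyPipeline` and `from_components` are hypothetical names for this illustration, and raising a `TypeError` for unknown keywords is an assumption of the sketch, not a statement about how the real loader reacts.

```python
import inspect


class ToyPipeline:
    """Hypothetical pipeline whose components are defined by its `__init__` signature."""

    def __init__(self, unet, scheduler):
        self.unet = unet
        self.scheduler = scheduler

    @classmethod
    def from_components(cls, defaults, **overrides):
        # Only names that appear in `__init__` are valid component keywords.
        expected = set(inspect.signature(cls.__init__).parameters) - {"self"}
        unknown = set(overrides) - expected
        if unknown:
            raise TypeError(f"Unexpected component keyword(s): {sorted(unknown)}")
        # Overrides replace the default components of the same name.
        return cls(**{**defaults, **overrides})


defaults = {"unet": "default-unet", "scheduler": "default-scheduler"}

# `scheduler` matches an `__init__` parameter, so the override is accepted.
pipe = ToyPipeline.from_components(defaults, scheduler="euler")

# `sampler` is not an `__init__` parameter, so it is rejected in this sketch.
try:
    ToyPipeline.from_components(defaults, sampler="euler")
    error_message = None
except TypeError as err:
    error_message = str(err)
```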
@@ -337,8 +337,8 @@ model = UNet2DModel.from_pretrained(repo_id)
## Loading schedulers
Schedulers cannot be loaded via a `from_pretrained` method, but instead rely on [`ConfigMixin.from_config`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
Therefore the loading method was given a different name here.
Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
For consistency, we use the same method name as we do for models or pipelines, but no weights are loaded in this case.
In contrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers.
For example, all of:
@@ -367,13 +367,13 @@ from diffusers import (
repo_id = "runwayml/stable-diffusion-v1-5"
ddpm = DDPMScheduler.from_config(repo_id, subfolder="scheduler")
ddim = DDIMScheduler.from_config(repo_id, subfolder="scheduler")
pndm = PNDMScheduler.from_config(repo_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
euler_anc = EulerAncestralDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
euler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc`
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)

View File

@@ -0,0 +1,262 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Schedulers
Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent of each other. This means that one is able to switch out parts of the pipeline to better customize
a pipeline to one's use case. The best example of this is the [Schedulers](../api/schedulers.mdx).
Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample,
schedulers define the whole denoising process, *i.e.*:
- How many denoising steps?
- Stochastic or deterministic?
- What algorithm to use to find the denoised sample?
They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**.
It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try out which works best.
The following paragraphs show how to do so with the 🧨 Diffusers library.
## Load pipeline
Let's start by loading the stable diffusion pipeline.
Remember that you have to be a registered user on the 🤗 Hugging Face Hub, and have "click-accepted" the [license](https://huggingface.co/runwayml/stable-diffusion-v1-5) in order to use stable diffusion.
```python
from huggingface_hub import login
from diffusers import DiffusionPipeline
import torch
# first we need to login with our access token
login()
# Now we can download the pipeline
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```
Next, we move it to GPU:
```python
pipeline.to("cuda")
```
## Access the scheduler
The scheduler is always one of the components of the pipeline and is usually called `"scheduler"`.
So it can be accessed via the `"scheduler"` property.
```python
pipeline.scheduler
```
**Output**:
```
PNDMScheduler {
"_class_name": "PNDMScheduler",
"_diffusers_version": "0.8.0.dev0",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": false,
"num_train_timesteps": 1000,
"set_alpha_to_one": false,
"skip_prk_steps": true,
"steps_offset": 1,
"trained_betas": null
}
```
We can see that the scheduler is of type [`PNDMScheduler`].
Cool, now let's compare this scheduler's performance with that of other schedulers.
First we define a prompt on which we will test all the different schedulers:
```python
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
```
Next, we create a generator from a fixed seed, so that we can reproduce similar images across runs, and then run the pipeline:
```python
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_pndm.png" width="400"/>
<br>
</p>
## Changing the scheduler
Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`SchedulerMixin.compatibles`]
which defines all compatible schedulers. You can take a look at all available, compatible schedulers for the Stable Diffusion pipeline as follows.
```python
pipeline.scheduler.compatibles
```
**Output**:
```
[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
diffusers.schedulers.scheduling_ddim.DDIMScheduler,
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
diffusers.schedulers.scheduling_pndm.PNDMScheduler,
diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]
```
Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions:
- [`LMSDiscreteScheduler`],
- [`DDIMScheduler`],
- [`DPMSolverMultistepScheduler`],
- [`EulerDiscreteScheduler`],
- [`PNDMScheduler`],
- [`DDPMScheduler`],
- [`EulerAncestralDiscreteScheduler`].
We will now run the same input prompt with all the other schedulers. To change the scheduler of the pipeline you can make use of the
convenient [`ConfigMixin.config`] property in combination with the [`ConfigMixin.from_config`] method.
```python
pipeline.scheduler.config
```
returns a dictionary of the configuration of the scheduler:
**Output**:
```
FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
('beta_schedule', 'scaled_linear'),
('trained_betas', None),
('skip_prk_steps', True),
('set_alpha_to_one', False),
('steps_offset', 1),
('_class_name', 'PNDMScheduler'),
('_diffusers_version', '0.8.0.dev0'),
('clip_sample', False)])
```
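The configuration is returned as a frozen dictionary, meaning it is intended to be read rather than modified in place. The following sketch illustrates that idea with the standard library's `types.MappingProxyType`. This is a swapped-in technique for illustration only and not the library's `FrozenDict` class; the two keys are copied from the output above.

```python
from types import MappingProxyType

# A read-only view over a plain dict, standing in for a frozen config.
config = MappingProxyType({"num_train_timesteps": 1000, "beta_schedule": "scaled_linear"})

# Reading works like a normal dictionary.
steps = config["num_train_timesteps"]

# Writing through the read-only view is rejected.
try:
    config["num_train_timesteps"] = 50
    write_error = None
except TypeError as err:
    write_error = type(err).__name__

# To change a value, build a new dictionary instead of mutating the original.
modified = {**config, "num_train_timesteps": 50}
```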
This configuration can then be used to instantiate a scheduler
of a different class that is compatible with the pipeline. Here,
we change the scheduler to the [`DDIMScheduler`].
```python
from diffusers import DDIMScheduler
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
```
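To make the mechanism more concrete, here is a minimal sketch, under stated assumptions, of how one class can be instantiated from another class's configuration: keys that the target `__init__` accepts are passed through, and the remaining keys are kept so they are not lost when switching back. `ToyPNDM`, `ToyDDIM`, and the free function `from_config` are hypothetical and much simpler than the library's code.

```python
import inspect


class ToyPNDM:
    def __init__(self, num_train_timesteps=1000, beta_start=0.0001, skip_prk_steps=False):
        self.config = {
            "num_train_timesteps": num_train_timesteps,
            "beta_start": beta_start,
            "skip_prk_steps": skip_prk_steps,
        }


class ToyDDIM:
    def __init__(self, num_train_timesteps=1000, beta_start=0.0001, clip_sample=True):
        self.config = {
            "num_train_timesteps": num_train_timesteps,
            "beta_start": beta_start,
            "clip_sample": clip_sample,
        }


def from_config(cls, config):
    # Pass on only the keys that the target class accepts in `__init__`.
    accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
    init_kwargs = {k: v for k, v in config.items() if k in accepted}
    instance = cls(**init_kwargs)
    # Keep the keys this class does not use, so a later switch can still see them.
    leftover = {k: v for k, v in config.items() if k not in accepted and not k.startswith("_")}
    instance.config = {**instance.config, **leftover}
    return instance


pndm = ToyPNDM(beta_start=0.00085, skip_prk_steps=True)
ddim = from_config(ToyDDIM, pndm.config)  # shared keys carry over
back = from_config(ToyPNDM, ddim.config)  # `skip_prk_steps` survives the round trip
```

Keeping the unused keys is what allows a round trip such as PNDM to DDIM and back without silently resetting `skip_prk_steps` to its default.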
Cool, now we can run the pipeline again to compare the generation quality.
```python
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_ddim.png" width="400"/>
<br>
</p>
## Compare schedulers
So far we have tried running the stable diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`].
A number of better schedulers have been released that can be run with far fewer steps. Let's compare them here:
[`LMSDiscreteScheduler`] usually leads to better results:
```python
from diffusers import LMSDiscreteScheduler
pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png" width="400"/>
<br>
</p>
[`EulerDiscreteScheduler`] and [`EulerAncestralDiscreteScheduler`] can generate high quality results with as few as 30 steps.
```python
from diffusers import EulerDiscreteScheduler
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png" width="400"/>
<br>
</p>
and:
```python
from diffusers import EulerAncestralDiscreteScheduler
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png" width="400"/>
<br>
</p>
At the time of writing this doc, [`DPMSolverMultistepScheduler`] arguably gives the best speed/quality trade-off and can be run with as few
as 20 steps.
```python
from diffusers import DPMSolverMultistepScheduler
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png" width="400"/>
<br>
</p>
As you can see, most images look very similar and are arguably of very similar quality. Which scheduler to choose often depends on the specific use case. A good approach is always to run multiple different
schedulers to compare results.
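The comparison steps used throughout this guide can be wrapped in a small helper. This is a sketch, not part of the library: `run_with_schedulers` and its `make_generator` parameter are hypothetical, and the helper only assumes what the guide already shows, namely that the pipeline has a `scheduler` attribute, that scheduler classes offer `from_config`, and that calling the pipeline returns an object with `images`.

```python
def run_with_schedulers(pipeline, scheduler_classes, prompt, make_generator, **call_kwargs):
    """Run `prompt` once per scheduler class and collect the first image of each run."""
    results = {}
    for scheduler_cls in scheduler_classes:
        # Swap the scheduler, reusing the configuration of the current one.
        pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
        # A fresh generator per run keeps the starting noise comparable across schedulers.
        generator = make_generator()
        output = pipeline(prompt, generator=generator, **call_kwargs)
        results[scheduler_cls.__name__] = output.images[0]
    return results
```

With the pipeline from this guide, a call might look like `run_with_schedulers(pipeline, [DDIMScheduler, EulerDiscreteScheduler], prompt, lambda: torch.Generator(device="cuda").manual_seed(8), num_inference_steps=30)`. Note that the helper leaves the last scheduler in the list attached to the pipeline.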

View File

@@ -29,7 +29,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError
from . import __version__
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
logger = logging.get_logger(__name__)
@@ -37,6 +37,38 @@ logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json")
class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
self.__frozen = True
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __setattr__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setattr__(name, value)
def __setitem__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setitem__(name, value)
class ConfigMixin:
r"""
Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
@@ -49,13 +81,12 @@ class ConfigMixin:
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class).
- **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
`from_config` can be used from a class different than the one used to save the config (should be overridden
by parent class).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
class).
"""
config_name = None
ignore_for_config = []
_compatible_classes = []
has_compatibles = False
def register_to_config(self, **kwargs):
if self.config_name is None:
@@ -104,9 +135,98 @@ class ConfigMixin:
logger.info(f"Configuration saved in {output_config_file}")
@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
r"""
Instantiate a Python class from a pre-defined JSON-file.
Instantiate a Python class from a config dictionary
Parameters:
config (`Dict[str, Any]`):
A config dictionary from which the Python class will be instantiated. Make sure to only load
configuration files of compatible classes.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the Python class.
`**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
overwrite same named arguments of `config`.
Examples:
```python
>>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
>>> # Download scheduler from huggingface.co and cache.
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
>>> # Instantiate DDIM scheduler class with same config as DDPM
>>> scheduler = DDIMScheduler.from_config(scheduler.config)
>>> # Instantiate PNDM scheduler class with same config as DDPM
>>> scheduler = PNDMScheduler.from_config(scheduler.config)
```
"""
# <===== TO BE REMOVED WITH DEPRECATION
# TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
if "pretrained_model_name_or_path" in kwargs:
config = kwargs.pop("pretrained_model_name_or_path")
if config is None:
raise ValueError("Please make sure to provide a config as the first positional argument.")
# ======>
if not isinstance(config, dict):
deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`. "
if "Scheduler" in cls.__name__:
deprecation_message += (
f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
" Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
" be removed in v1.0.0."
)
elif "Model" in cls.__name__:
deprecation_message += (
f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
" instead. This functionality will be removed in v1.0.0."
)
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
# Allow dtype to be specified on initialization
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")
# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict)
# make sure to also save config parameters that might be used for compatible classes
model.register_to_config(**hidden_dict)
# add hidden kwargs of compatible classes to unused_kwargs
unused_kwargs = {**unused_kwargs, **hidden_dict}
if return_unused_kwargs:
return (model, unused_kwargs)
else:
return model
@classmethod
def get_config_dict(cls, *args, **kwargs):
deprecation_message = (
f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
" removed in version v1.0.0"
)
deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
return cls.load_config(*args, **kwargs)
@classmethod
def load_config(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
r"""
Load a config dictionary from a pretrained model name or path
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
@@ -120,10 +240,6 @@ class ConfigMixin:
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -161,33 +277,7 @@ class ConfigMixin:
use this method in a firewalled environment.
</Tip>
"""
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
# Allow dtype to be specified on initialization
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")
# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict)
return_tuple = (model,)
# Flax schedulers have a state, so return it.
if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
state = model.create_state()
return_tuple += (state,)
if return_unused_kwargs:
return return_tuple + (unused_kwargs,)
else:
return return_tuple if len(return_tuple) > 1 else model
@classmethod
def get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
@@ -283,6 +373,9 @@ class ConfigMixin:
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
if return_unused_kwargs:
return config_dict, kwargs
return config_dict
@staticmethod
@@ -291,6 +384,9 @@ class ConfigMixin:
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
# 0. Copy origin config dict
original_dict = {k: v for k, v in config_dict.items()}
# 1. Retrieve expected config attributes from __init__ signature
expected_keys = cls._get_init_keys(cls)
expected_keys.remove("self")
@@ -310,10 +406,11 @@ class ConfigMixin:
# load diffusers library to import compatible and original scheduler
diffusers_library = importlib.import_module(__name__.split(".")[0])
# remove attributes from compatible classes that orig cannot expect
compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
# filter out None potentially undefined dummy classes
compatible_classes = [c for c in compatible_classes if c is not None]
if cls.has_compatibles:
compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
else:
compatible_classes = []
expected_keys_comp_cls = set()
for c in compatible_classes:
expected_keys_c = cls._get_init_keys(c)
@@ -364,7 +461,10 @@ class ConfigMixin:
# 6. Define unused keyword arguments
unused_kwargs = {**config_dict, **kwargs}
return init_dict, unused_kwargs
# 7. Define "hidden" config parameters that were saved for compatible classes
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")}
return init_dict, unused_kwargs, hidden_config_dict
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
@@ -377,6 +477,12 @@ class ConfigMixin:
@property
def config(self) -> Dict[str, Any]:
"""
Returns the config of the class as a frozen dictionary
Returns:
`Dict[str, Any]`: Config of the class.
"""
return self._internal_dict
def to_json_string(self) -> str:
@@ -401,38 +507,6 @@ class ConfigMixin:
writer.write(self.to_json_string())
class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
self.__frozen = True
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __setattr__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setattr__(name, value)
def __setitem__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setitem__(name, value)
def register_to_config(init):
r"""
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
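
The `hidden_config_dict` that this file's diff adds to `extract_init_dict` can be illustrated with a minimal, self-contained sketch. The config values below, and the choice of `skip_prk_steps` as the leftover key, are assumptions made for illustration and are not taken from a real checkpoint. Only the dictionary comprehension mirrors the line added in the diff.

```python
# Hypothetical saved config: written by one scheduler class, loaded by another.
original_dict = {
    "_class_name": "PNDMScheduler",
    "_diffusers_version": "0.8.0",
    "num_train_timesteps": 1000,
    "beta_start": 0.00085,
    "skip_prk_steps": True,  # only meaningful to the class that saved the config
}

# Keys that the loading class accepts in its __init__ (assumed for illustration).
init_dict = {"num_train_timesteps": 1000, "beta_start": 0.00085}

# Same filter as in the diff: keep keys that are neither init arguments nor private ("_"-prefixed).
hidden_config_dict = {
    k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")
}

print(hidden_config_dict)  # {'skip_prk_steps': True}
```

Under these assumptions, the leftover key is kept alongside the config rather than discarded, which is what would let a compatible class pick it up again later.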

View File

@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
"FlaxSchedulerMixin": ["save_config", "from_config"],
"FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
},
"transformers": {
@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
>>> from diffusers import FlaxDPMSolverMultistepScheduler
>>> model_id = "runwayml/stable-diffusion-v1-5"
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_config(
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
... model_id,
... subfolder="scheduler",
... )
@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.get_config_dict(
config_dict = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
else:
cached_folder = pretrained_model_name_or_path
config_dict = cls.get_config_dict(cached_folder)
config_dict = cls.load_config(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}

View File

@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_config", "from_config"],
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
},
@@ -207,7 +207,7 @@ class DiffusionPipeline(ConfigMixin):
if torch_device is None:
return self
module_names, _ = self.extract_init_dict(dict(self.config))
module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
@@ -228,7 +228,7 @@ class DiffusionPipeline(ConfigMixin):
Returns:
`torch.device`: The torch device on which the pipeline is located.
"""
module_names, _ = self.extract_init_dict(dict(self.config))
module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
@@ -377,11 +377,11 @@ class DiffusionPipeline(ConfigMixin):
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # Download pipeline, but overwrite scheduler
>>> # Use a different scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
>>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> pipeline.scheduler = scheduler
```
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
@@ -428,7 +428,7 @@ class DiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.get_config_dict(
config_dict = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
@@ -474,7 +474,7 @@ class DiffusionPipeline(ConfigMixin):
else:
cached_folder = pretrained_model_name_or_path
config_dict = cls.get_config_dict(cached_folder)
config_dict = cls.load_config(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
@@ -513,7 +513,7 @@ class DiffusionPipeline(ConfigMixin):
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
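
Since `extract_init_dict` now returns three values, each call site in this diff gains an extra `_`. The sketch below shows why the change is needed, using a hypothetical stand-in function: `extract_init_dict_sketch` and the toy config are assumptions, not the library's implementation.

```python
def extract_init_dict_sketch(config_dict):
    """Hypothetical stand-in that only mimics the new three-value return shape."""
    init_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
    unused_kwargs = {}
    hidden_config_dict = {}
    return init_dict, unused_kwargs, hidden_config_dict


config = {"_class_name": "ToyPipeline", "scheduler": ["diffusers", "PNDMScheduler"]}

# Old-style two-value unpacking no longer matches the return shape.
try:
    init_dict, unused_kwargs = extract_init_dict_sketch(config)
    old_style_error = None
except ValueError as err:
    old_style_error = str(err)

# New-style unpacking discards the hidden config with `_`.
init_dict, unused_kwargs, _ = extract_init_dict_sketch(config)
```

Any caller outside this diff that still unpacks two values would hit the same `ValueError`.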

View File

@@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline):
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)

View File

@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, DDIMScheduler
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image

View File

@@ -23,7 +23,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin
@@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
@@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"PNDMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
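
Each scheduler takes a `.copy()` of the shared list instead of referencing it directly. The diff does not state the motivation. One plausible reason, sketched here with hypothetical class names and a shortened list, is that a class-level list is shared by reference, so mutating it through one class would otherwise change it for every other class.

```python
# Shortened, hypothetical stand-in for the shared constant imported from `..utils`.
_COMPATIBLE_SCHEDULERS = ["DDIMScheduler", "PNDMScheduler", "LMSDiscreteScheduler"]


class OwnCopy:
    # Independent list: changes here stay local to this class.
    _compatibles = _COMPATIBLE_SCHEDULERS.copy()


class SharedRef:
    # Same list object as the module-level constant.
    _compatibles = _COMPATIBLE_SCHEDULERS


OwnCopy._compatibles.append("MyCustomScheduler")
leaked_after_copy = "MyCustomScheduler" in _COMPATIBLE_SCHEDULERS

SharedRef._compatibles.append("MyCustomScheduler")
leaked_after_shared = "MyCustomScheduler" in _COMPATIBLE_SCHEDULERS
```

Here `leaked_after_copy` stays `False` while `leaked_after_shared` becomes `True`, which is the difference the `.copy()` call makes.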

View File

@@ -23,7 +23,12 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
@@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
stable diffusion.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@property
def has_state(self):
return True

View File

@@ -22,7 +22,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import BaseOutput, deprecate
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
@@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
@@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"DDIMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
@@ -249,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:

View File

@@ -24,7 +24,12 @@ from jax import random
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
@@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@property
def has_state(self):
return True
@@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:

View File

@@ -21,6 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
]
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(

View File

@@ -23,7 +23,12 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
@@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@property
def has_state(self):
return True

View File

@@ -19,7 +19,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
from .scheduling_utils import SchedulerMixin
@@ -52,8 +52,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -67,14 +67,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"PNDMScheduler",
"EulerDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(

View File

@@ -19,7 +19,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
from .scheduling_utils import SchedulerMixin
@@ -53,8 +53,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -68,14 +68,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"PNDMScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(

View File

@@ -28,8 +28,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778

View File

@@ -56,8 +56,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the

View File

@@ -67,8 +67,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the

View File

@@ -21,7 +21,7 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin
@@ -52,8 +52,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -67,14 +67,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(

View File

@@ -20,7 +20,12 @@ import jax.numpy as jnp
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
@flax.struct.dataclass
@@ -49,8 +54,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -63,6 +68,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@property
def has_state(self):
return True

View File

@@ -21,6 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -60,8 +61,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
@@ -88,14 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(

View File

@@ -23,7 +23,12 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -87,8 +92,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
@@ -114,6 +119,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
stable diffusion.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@property
def has_state(self):
return True

View File

@@ -77,8 +77,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf

View File

@@ -50,8 +50,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.

View File

@@ -64,8 +64,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.

View File

@@ -29,8 +29,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more information, see the original paper: https://arxiv.org/abs/2011.13456

View File

@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import torch
@@ -38,6 +41,114 @@ class SchedulerOutput(BaseOutput):
class SchedulerMixin:
"""
Mixin containing common functions for the schedulers.
Class attributes:
- **_compatibles** (`List[str]`) -- A list of class names that are compatible with this scheduler class, so that
`from_config` can be used from a class different than the one used to save the config (should be overridden
by subclasses).
"""
config_name = SCHEDULER_CONFIG_NAME
_compatibles = []
has_compatibles = True
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
):
r"""
Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing the scheduler configurations saved using
[`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
subfolder (`str`, *optional*):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.
</Tip>
"""
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
return_unused_kwargs=True,
**kwargs,
)
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~SchedulerMixin.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
"""
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
@property
def compatibles(self):
"""
Returns all schedulers that are compatible with this scheduler.
Returns:
`List[SchedulerMixin]`: List of compatible schedulers
"""
return self._get_compatibles()
@classmethod
def _get_compatibles(cls):
compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
diffusers_library = importlib.import_module(__name__.split(".")[0])
compatible_classes = [
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
]
return compatible_classes
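The name lookup in `_get_compatibles` can be tried in isolation. The sketch below is a simplified stand-in, not the library code: it receives the namespace as an argument instead of importing the top-level package, and `ToyMixin`, `ToyDDIM`, `ToyPNDM`, and `toy_library` are hypothetical names. It shows two properties of the lookup: a class always counts as compatible with itself, and names that the namespace does not export are skipped silently.

```python
import types


class ToyMixin:
    # Class names (strings) that this class can exchange configs with.
    _compatibles = []

    @classmethod
    def _get_compatibles(cls, library):
        # Same shape as the hunk above: own name plus declared names, deduplicated,
        # then resolved against a namespace and filtered with hasattr.
        names = list(set([cls.__name__] + cls._compatibles))
        return [getattr(library, n) for n in names if hasattr(library, n)]


class ToyDDIM(ToyMixin):
    _compatibles = ["ToyPNDM", "ToyNotExported"]


class ToyPNDM(ToyMixin):
    _compatibles = ["ToyDDIM"]


# Stand-in for the package namespace; "ToyNotExported" is deliberately absent.
toy_library = types.SimpleNamespace(ToyDDIM=ToyDDIM, ToyPNDM=ToyPNDM)

resolved = ToyDDIM._get_compatibles(toy_library)
resolved_names = sorted(c.__name__ for c in resolved)
print(resolved_names)
```

Because the names pass through a `set`, the order of the returned list is not stable, which is why the sketch sorts before printing.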

View File

@@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from dataclasses import dataclass
from typing import Tuple
from typing import Any, Dict, Optional, Tuple, Union
import jax.numpy as jnp
from ..utils import BaseOutput
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS]
@dataclass
@@ -39,9 +42,123 @@ class FlaxSchedulerOutput(BaseOutput):
class FlaxSchedulerMixin:
"""
Mixin containing common functions for the schedulers.
Class attributes:
- **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
`from_config` can be used from a class different than the one used to save the config (should be overridden
by parent class).
"""
config_name = SCHEDULER_CONFIG_NAME
_compatibles = []
has_compatibles = True
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike] = None,
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
):
r"""
Instantiate a Scheduler class from a pre-defined JSON-file.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing the scheduler configuration saved using
[`~FlaxSchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
subfolder (`str`, *optional*):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info (`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.
</Tip>
"""
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
return_unused_kwargs=True,
**kwargs,
)
scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)
if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
state = scheduler.create_state()
if return_unused_kwargs:
return scheduler, state, unused_kwargs
return scheduler, state
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~FlaxSchedulerMixin.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
"""
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
@property
def compatibles(self):
"""
Returns all schedulers that are compatible with this scheduler.
Returns:
`List[FlaxSchedulerMixin]`: List of compatible schedulers
"""
return self._get_compatibles()
@classmethod
def _get_compatibles(cls):
compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
diffusers_library = importlib.import_module(__name__.split(".")[0])
compatible_classes = [
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
]
return compatible_classes
def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
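Unlike the PyTorch mixin, the Flax `from_pretrained` returns a tuple: the scheduler plus a separately created state. The sketch below mirrors only that return contract with hypothetical toy classes; it takes an already-built scheduler rather than loading a config. One deliberate difference is labeled in the code: in the hunk above, `state` is only assigned inside the `if` branch, whereas this sketch binds it to `None` first so that a scheduler without state still produces a well-defined result.

```python
class ToyStatefulScheduler:
    has_state = True

    def create_state(self):
        return {"step": 0}


class ToyStatelessScheduler:
    has_state = False


def toy_from_pretrained(scheduler, return_unused_kwargs=False, **unused_kwargs):
    # Assumption of this sketch (not in the hunk above): default to None
    # so the name is bound even when the scheduler carries no state.
    state = None
    if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
        state = scheduler.create_state()
    if return_unused_kwargs:
        return scheduler, state, unused_kwargs
    return scheduler, state


# Callers unpack two values, or three when they ask for the leftover kwargs.
scheduler, state = toy_from_pretrained(ToyStatefulScheduler())
_, no_state = toy_from_pretrained(ToyStatelessScheduler())
_, _, leftovers = toy_from_pretrained(ToyStatefulScheduler(), return_unused_kwargs=True, foo=1)
```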

View File

@@ -112,8 +112,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2111.14822

View File

@@ -72,3 +72,13 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]

View File

@@ -67,8 +67,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
super().test_from_pretrained_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_config(self):
super().test_model_from_config()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self):
@@ -187,8 +187,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
super().test_from_pretrained_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_config(self):
super().test_model_from_config()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self):

View File

@@ -75,7 +75,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
model_id = "google/ddpm-ema-bedroom-256"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler.from_config(model_id)
scheduler = DDIMScheduler.from_pretrained(model_id)
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)

View File

@@ -106,7 +106,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDPMScheduler.from_config(model_id)
scheduler = DDPMScheduler.from_pretrained(model_id)
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)

View File

@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
model_id = "google/ddpm-ema-celebahq-256"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = RePaintScheduler.from_config(model_id)
scheduler = RePaintScheduler.from_pretrained(model_id)
repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device)

View File

@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
model_id = "google/ncsnpp-church-256"
model = UNet2DModel.from_pretrained(model_id)
scheduler = ScoreSdeVeScheduler.from_config(model_id)
scheduler = ScoreSdeVeScheduler.from_pretrained(model_id)
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
sde_ve.to(torch_device)

View File

@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((512, 512))
model_id = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(
model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16"
)
@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((512, 512))
model_id = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)
pipe.to(torch_device)

View File

@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_inference_ddim(self):
ddim_scheduler = DDIMScheduler.from_config(
ddim_scheduler = DDIMScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
)
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_inference_k_lms(self):
lms_scheduler = LMSDiscreteScheduler.from_config(
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
)
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(

View File

@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
lms_scheduler = LMSDiscreteScheduler.from_config(
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
)
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(

View File

@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
lms_scheduler = LMSDiscreteScheduler.from_config(
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx"
)
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(

View File

@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_fast_ddim(self):
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
sd_pipe = sd_pipe.to(torch_device)
@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
pipe.set_progress_bar_config(disable=None)
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe.scheduler = scheduler
prompt = "a photograph of an astronaut riding a horse"

View File

@@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)
model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
scheduler=lms,
@@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)
model_id = "CompVis/stable-diffusion-v1-4"
ddim = DDIMScheduler.from_config(model_id, subfolder="scheduler")
ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
scheduler=ddim,
@@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((768, 512))
model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
)

View File

@@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)
model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)
model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
safety_checker=None,

View File

@@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
)
model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
scheduler=lms,

View File

@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import unittest
import diffusers
from diffusers import (
DDIMScheduler,
DDPMScheduler,
@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin):
class ConfigTester(unittest.TestCase):
def test_load_not_from_mixin(self):
with self.assertRaises(ValueError):
ConfigMixin.from_config("dummy_path")
ConfigMixin.load_config("dummy_path")
def test_register_to_config(self):
obj = SampleObject()
@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
new_obj = SampleObject.from_config(tmpdirname)
new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname))
new_config = new_obj.config
# unfreeze configs
@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase):
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
assert config == new_config
def test_save_load_from_different_config(self):
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
new_obj_1 = SampleObject2.from_config(tmpdirname)
# now save a config parameter that is not expected
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
data = json.load(f)
data["unexpected"] = True
with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger_2:
new_obj_2 = SampleObject.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_3:
new_obj_3 = SampleObject2.from_config(tmpdirname)
assert new_obj_1.__class__ == SampleObject2
assert new_obj_2.__class__ == SampleObject
assert new_obj_3.__class__ == SampleObject2
assert cap_logger_1.out == ""
assert (
cap_logger_2.out
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.\n"
)
assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out
def test_save_load_compatible_schedulers(self):
SampleObject2._compatible_classes = ["SampleObject"]
SampleObject._compatible_classes = ["SampleObject2"]
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
setattr(diffusers, "SampleObject2", SampleObject2)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
# now save a config parameter that is expected by another class, but not origin class
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
data = json.load(f)
data["f"] = [0, 0]
data["unexpected"] = True
with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger:
new_obj = SampleObject.from_config(tmpdirname)
assert new_obj.__class__ == SampleObject
assert (
cap_logger.out
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.\n"
)
def test_save_load_from_different_config_comp_schedulers(self):
SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"]
SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"]
SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"]
obj = SampleObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
setattr(diffusers, "SampleObject2", SampleObject2)
setattr(diffusers, "SampleObject3", SampleObject3)
logger = logging.get_logger("diffusers.configuration_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
new_obj_1 = SampleObject.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_2:
new_obj_2 = SampleObject2.from_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_3:
new_obj_3 = SampleObject3.from_config(tmpdirname)
assert new_obj_1.__class__ == SampleObject
assert new_obj_2.__class__ == SampleObject2
assert new_obj_3.__class__ == SampleObject3
assert cap_logger_1.out == ""
assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
def test_load_ddim_from_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
ddim = DDIMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
assert ddim.__class__ == DDIMScheduler
# no warning should be thrown
@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
euler = EulerDiscreteScheduler.from_config(
euler = EulerDiscreteScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
euler = EulerAncestralDiscreteScheduler.from_config(
euler = EulerAncestralDiscreteScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
assert pndm.__class__ == PNDMScheduler
# no warning should be thrown
@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
ddpm = DDPMScheduler.from_config(
ddpm = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler",
predict_epsilon=False,
@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase):
)
with CaptureLogger(logger) as cap_logger_2:
ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88)
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
assert ddpm.__class__ == DDPMScheduler
assert ddpm.config.predict_epsilon is False
@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
dpm = DPMSolverMultistepScheduler.from_config(
dpm = DPMSolverMultistepScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
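The `ConfigTester` changes in this file reflect a split of responsibilities: `load_config` turns a saved file into a plain dictionary, `from_config` turns a dictionary into an object, and `from_pretrained` chains the two. The sketch below reproduces that split with a hypothetical `ToyConfigurable` class built on the standard library only. It is a simplified illustration and leaves out what the real mixins handle, such as Hub downloads, subfolders, and unused-kwarg reporting.

```python
import json
import os
import tempfile


class ToyConfigurable:
    config_name = "config.json"

    def __init__(self, a=2, b=5):
        self.config = {"a": a, "b": b}

    def save_config(self, save_directory):
        os.makedirs(save_directory, exist_ok=True)
        with open(os.path.join(save_directory, self.config_name), "w") as f:
            json.dump(self.config, f)

    @classmethod
    def load_config(cls, directory):
        # Step 1: file on disk -> plain dict, no object is created yet.
        with open(os.path.join(directory, cls.config_name)) as f:
            return json.load(f)

    @classmethod
    def from_config(cls, config):
        # Step 2: dict -> instance.
        return cls(**config)

    @classmethod
    def from_pretrained(cls, directory):
        # Convenience wrapper: both steps in one call.
        return cls.from_config(cls.load_config(directory))


with tempfile.TemporaryDirectory() as tmpdirname:
    ToyConfigurable(a=7).save_config(tmpdirname)
    loaded = ToyConfigurable.load_config(tmpdirname)
    rebuilt = ToyConfigurable.from_config(loaded)
    shortcut = ToyConfigurable.from_pretrained(tmpdirname)
```

Keeping the dictionary step separate is what lets a config that is already in memory, such as `pipe.scheduler.config`, be passed straight to `from_config` without touching the disk.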

View File

@@ -130,7 +130,7 @@ class ModelTesterMixin:
expected_arg_names = ["sample", "timestep"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_model_from_config(self):
def test_model_from_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
@@ -140,8 +140,8 @@ class ModelTesterMixin:
# test if the model can be loaded from the config
# and has all the expected shape
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_config(tmpdirname)
new_model = self.model_class.from_config(tmpdirname)
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
new_model.eval()

View File

@@ -29,6 +29,10 @@ from diffusers import (
DDIMScheduler,
DDPMPipeline,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipelineLegacy,
@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
assert image_img2img.shape == (1, 32, 32, 3)
assert image_text2img.shape == (1, 128, 128, 3)
def test_set_scheduler(self):
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
sd = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, DDIMScheduler)
sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, DDPMScheduler)
sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, PNDMScheduler)
sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, LMSDiscreteScheduler)
sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, EulerDiscreteScheduler)
sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler)
sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
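The pattern exercised by `test_set_scheduler` is `pipe.scheduler = OtherScheduler.from_config(pipe.scheduler.config)`. A minimal sketch of why this can work across classes follows, under the assumption that `from_config` keeps the keys the target class accepts and lets the rest fall back to defaults. All class names here are hypothetical, and the real `from_config` does more bookkeeping than this.

```python
import inspect


class ToyScheduler:
    @classmethod
    def from_config(cls, config):
        # Keep only the keys that the target __init__ accepts; everything
        # else in the incoming config is ignored, and missing arguments
        # take their default values.
        expected = inspect.signature(cls.__init__).parameters
        init_kwargs = {k: v for k, v in config.items() if k in expected}
        return cls(**init_kwargs)


class ToyPNDM(ToyScheduler):
    def __init__(self, num_train_timesteps=1000, skip_prk_steps=False):
        self.config = {"num_train_timesteps": num_train_timesteps, "skip_prk_steps": skip_prk_steps}


class ToyDDIM(ToyScheduler):
    def __init__(self, num_train_timesteps=1000, clip_sample=True):
        self.config = {"num_train_timesteps": num_train_timesteps, "clip_sample": clip_sample}


class ToyPipeline:
    def __init__(self, scheduler):
        self.scheduler = scheduler


pipe = ToyPipeline(ToyPNDM(num_train_timesteps=500, skip_prk_steps=True))

# Swap the scheduler in place, carrying over the shared setting.
pipe.scheduler = ToyDDIM.from_config(pipe.scheduler.config)
print(pipe.scheduler.config)
```

The shared key `num_train_timesteps` survives the swap, the PNDM-only key is dropped, and the DDIM-only key takes its default.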
def test_set_scheduler_consistency(self):
unet = self.dummy_cond_unet
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
sd = StableDiffusionPipeline(
unet=unet,
scheduler=pndm,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
pndm_config = sd.scheduler.config
sd.scheduler = DDPMScheduler.from_config(pndm_config)
sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
pndm_config_2 = sd.scheduler.config
pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config}
assert dict(pndm_config) == dict(pndm_config_2)
sd = StableDiffusionPipeline(
unet=unet,
scheduler=ddim,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
ddim_config = sd.scheduler.config
sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config)
sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
ddim_config_2 = sd.scheduler.config
ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config}
assert dict(ddim_config) == dict(ddim_config_2)
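The consistency test compares the original config only against the keys it started with, via the dictionary comprehension `{k: v for k, v in ... if k in pndm_config}`. The sketch below shows that filtering step on plain dictionaries. It assumes, for illustration, that a detour through another scheduler class leaves extra keys behind; the extra keys and all values are made up.

```python
# Config of the scheduler before the detour (illustrative values).
pndm_config = {
    "num_train_timesteps": 1000,
    "beta_schedule": "scaled_linear",
    "skip_prk_steps": True,
}

# Hypothetical config after switching to another scheduler and back:
# the original values are intact, but two foreign keys tagged along.
pndm_config_2 = dict(pndm_config, clip_sample=False, variance_type="fixed_small")

# Restrict the comparison to the keys that existed originally.
filtered = {k: v for k, v in pndm_config_2.items() if k in pndm_config}

print(filtered == pndm_config)
```

A plain equality check on the unfiltered dictionaries would fail here even though nothing the original scheduler cares about has changed.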
@slow
class PipelineSlowTests(unittest.TestCase):
@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_output_format(self):
model_path = "google/ddpm-cifar10-32"
scheduler = DDIMScheduler.from_config(model_path)
scheduler = DDIMScheduler.from_pretrained(model_path)
pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import tempfile
import unittest
from typing import Dict, List, Tuple
@@ -21,6 +23,7 @@ import numpy as np
import torch
import torch.nn.functional as F
import diffusers
from diffusers import (
DDIMScheduler,
DDPMScheduler,
@@ -32,13 +35,180 @@ from diffusers import (
PNDMScheduler,
ScoreSdeVeScheduler,
VQDiffusionScheduler,
logging,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import deprecate, torch_device
from diffusers.utils.testing_utils import CaptureLogger
torch.backends.cuda.matmul.allow_tf32 = False
class SchedulerObject(SchedulerMixin, ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
e=[1, 3],
):
pass
class SchedulerObject2(SchedulerMixin, ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
f=[1, 3],
):
pass
class SchedulerObject3(SchedulerMixin, ConfigMixin):
config_name = "config.json"
@register_to_config
def __init__(
self,
a=2,
b=5,
c=(2, 5),
d="for diffusion",
e=[1, 3],
f=[1, 3],
):
pass
class SchedulerBaseTests(unittest.TestCase):
def test_save_load_from_different_config(self):
obj = SchedulerObject()
# mock add obj class to `diffusers`
setattr(diffusers, "SchedulerObject", SchedulerObject)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
config = SchedulerObject2.load_config(tmpdirname)
new_obj_1 = SchedulerObject2.from_config(config)
# now save a config parameter that is not expected
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
data = json.load(f)
data["unexpected"] = True
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger_2:
config = SchedulerObject.load_config(tmpdirname)
new_obj_2 = SchedulerObject.from_config(config)
with CaptureLogger(logger) as cap_logger_3:
config = SchedulerObject2.load_config(tmpdirname)
new_obj_3 = SchedulerObject2.from_config(config)
assert new_obj_1.__class__ == SchedulerObject2
assert new_obj_2.__class__ == SchedulerObject
assert new_obj_3.__class__ == SchedulerObject2
assert cap_logger_1.out == ""
assert (
cap_logger_2.out
== "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
" will"
" be ignored. Please verify your config.json configuration file.\n"
)
assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out
def test_save_load_compatible_schedulers(self):
SchedulerObject2._compatibles = ["SchedulerObject"]
SchedulerObject._compatibles = ["SchedulerObject2"]
obj = SchedulerObject()
# mock: attach the classes to the `diffusers` module so they can be looked up by name
setattr(diffusers, "SchedulerObject", SchedulerObject)
setattr(diffusers, "SchedulerObject2", SchedulerObject2)
logger = logging.get_logger("diffusers.configuration_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
# now save a config parameter that is expected by another class, but not by the original class
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
data = json.load(f)
data["f"] = [0, 0]
data["unexpected"] = True
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
json.dump(data, f)
with CaptureLogger(logger) as cap_logger:
config = SchedulerObject.load_config(tmpdirname)
new_obj = SchedulerObject.from_config(config)
assert new_obj.__class__ == SchedulerObject
assert (
cap_logger.out
== "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
" will"
" be ignored. Please verify your config.json configuration file.\n"
)
def test_save_load_from_different_config_comp_schedulers(self):
SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"]
SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"]
SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"]
obj = SchedulerObject()
# mock: attach the classes to the `diffusers` module so they can be looked up by name
setattr(diffusers, "SchedulerObject", SchedulerObject)
setattr(diffusers, "SchedulerObject2", SchedulerObject2)
setattr(diffusers, "SchedulerObject3", SchedulerObject3)
logger = logging.get_logger("diffusers.configuration_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
config = SchedulerObject.load_config(tmpdirname)
new_obj_1 = SchedulerObject.from_config(config)
with CaptureLogger(logger) as cap_logger_2:
config = SchedulerObject2.load_config(tmpdirname)
new_obj_2 = SchedulerObject2.from_config(config)
with CaptureLogger(logger) as cap_logger_3:
config = SchedulerObject3.load_config(tmpdirname)
new_obj_3 = SchedulerObject3.from_config(config)
assert new_obj_1.__class__ == SchedulerObject
assert new_obj_2.__class__ == SchedulerObject2
assert new_obj_3.__class__ == SchedulerObject3
assert cap_logger_1.out == ""
assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
class SchedulerCommonTest(unittest.TestCase):
scheduler_classes = ()
forward_default_kwargs = ()
@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_compatibles(self):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
assert all(c is not None for c in scheduler.compatibles)
for comp_scheduler_cls in scheduler.compatibles:
comp_scheduler = comp_scheduler_cls.from_config(scheduler.config)
assert comp_scheduler is not None
new_scheduler = scheduler_class.from_config(comp_scheduler.config)
new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config}
scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config}
# make sure that configs are essentially identical
assert new_scheduler_config == dict(scheduler.config)
# make sure that only differences are for configs that are not in init
init_keys = inspect.signature(scheduler_class.__init__).parameters.keys()
assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set()
def test_from_pretrained(self):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_pretrained(tmpdirname)
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
assert scheduler.config == new_scheduler.config
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps)
@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps)
@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
output = scheduler.step_pred(
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
output = scheduler.step_pred(
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps)
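
The hunks above move the tests from calling `from_config` with a directory to either `from_pretrained(directory)` or the two-step `load_config` followed by `from_config`. As a rough illustration of that split, here is a minimal stdlib-only sketch. Every name in it (`ToyConfigMixin`, `ToySchedulerA`, `ToySchedulerB`, `REGISTRY`, `round_trip`) is hypothetical, and the key-filtering logic is an assumption modelled on what the test assertions suggest, not the diffusers implementation.

```python
import json
import logging
import os
import tempfile

logger = logging.getLogger("toy_config")

# Hypothetical registry so a loader can find the class that saved a config.
REGISTRY = {}


class ToyConfigMixin:
    """Toy stand-in for a config mixin; NOT the diffusers implementation."""

    config_name = "config.json"
    defaults = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        REGISTRY[cls.__name__] = cls

    def __init__(self, **kwargs):
        self.config = {**self.defaults, **kwargs}

    def save_config(self, directory):
        # Record the saving class so loaders can tell where keys came from.
        data = {"_class_name": type(self).__name__, **self.config}
        with open(os.path.join(directory, self.config_name), "w") as f:
            json.dump(data, f)

    @classmethod
    def load_config(cls, directory):
        # Step 1: read the raw dictionary from disk, nothing else.
        with open(os.path.join(directory, cls.config_name)) as f:
            return json.load(f)

    @classmethod
    def from_config(cls, config):
        # Step 2: build an instance from an in-memory dictionary.
        config = dict(config)
        origin = REGISTRY.get(config.pop("_class_name", None))
        origin_keys = set(origin.defaults) if origin else set()
        # Keys that belong to the saving class are dropped silently (assumption).
        unexpected = {
            k: v for k, v in config.items()
            if k not in cls.defaults and k not in origin_keys
        }
        missing = set(cls.defaults) - set(config)
        if unexpected:
            logger.warning("Ignoring unexpected config attributes: %s", unexpected)
        if missing:
            logger.info("%s not found in config; using defaults.", missing)
        return cls(**{k: v for k, v in config.items() if k in cls.defaults})

    @classmethod
    def from_pretrained(cls, directory):
        # Convenience wrapper: load from disk, then instantiate.
        return cls.from_config(cls.load_config(directory))


class ToySchedulerA(ToyConfigMixin):
    defaults = {"a": 2, "b": 5, "e": [1, 3]}


class ToySchedulerB(ToyConfigMixin):
    defaults = {"a": 2, "b": 5, "f": [1, 3]}


def round_trip():
    """Save with one class, reload with the same class and with a sibling."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        ToySchedulerA(a=7).save_config(tmpdirname)
        same = ToySchedulerA.from_pretrained(tmpdirname)
        other = ToySchedulerB.from_config(ToySchedulerB.load_config(tmpdirname))
    return same.config, other.config
```

In this toy version, reloading with the sibling class keeps the shared keys, drops `e`, and fills `f` from the defaults, which is roughly the behaviour the `SchedulerBaseTests` assertions check for.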

View File

@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
# copy over dummy past residuals
new_state = new_state.replace(ets=dummy_past_residuals[:])
@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
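
In the Flax hunks above, loading returns a `(scheduler, state)` pair, and every update yields a new state object (`set_timesteps(state, ...)`, `new_state.replace(ets=...)`). The sketch below imitates that functional pattern with a frozen stdlib dataclass. It is only an illustration under stated assumptions: `ToyState`, `ToyFunctionalScheduler`, `from_dict`, and `demo` are hypothetical names, and `dataclasses.replace` stands in for the `.replace` method used on the real state objects.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class ToyState:
    """Immutable scheduler state; updates return a new instance."""

    num_inference_steps: int = 0
    ets: Tuple[float, ...] = ()


class ToyFunctionalScheduler:
    """Toy stateless scheduler; NOT a diffusers class."""

    def __init__(self, num_train_timesteps=1000):
        self.num_train_timesteps = num_train_timesteps

    def create_state(self):
        return ToyState()

    def set_timesteps(self, state, num_inference_steps):
        # The input state is left untouched; a modified copy is returned.
        return replace(state, num_inference_steps=num_inference_steps)

    @classmethod
    def from_dict(cls, config):
        # Hypothetical loader: returns the scheduler together with a fresh state.
        scheduler = cls(**config)
        return scheduler, scheduler.create_state()


def demo():
    scheduler, state = ToyFunctionalScheduler.from_dict({"num_train_timesteps": 50})
    new_state = scheduler.set_timesteps(state, 10)
    # Copy over dummy past residuals, again by building a new state.
    new_state = replace(new_state, ets=(0.1, 0.2))
    return state, new_state
```

Because the state is frozen, the original `state` returned by the loader is unchanged after the calls in `demo`, which is why the tests always rebind the result of each state-updating call.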