Mirror of https://github.com/huggingface/diffusers.git
Add Scheduler.from_pretrained and better scheduler changing (#1286)
* add conversion script for vae
* uP
* uP
* more changes
* push
* up
* finish again
* up
* up
* up
* up
* finish
* up
* uP
* up
* Apply suggestions from code review
* up
* up

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Committed by GitHub. Parent: db1cb0b1a2 · Commit: a0520193e1

README.md: 10 lines changed
@@ -152,15 +152,7 @@ it before the pipeline and pass it to `from_pretrained`.
```python
from diffusers import LMSDiscreteScheduler

-lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
-
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
-    scheduler=lms,
)
pipe = pipe.to("cuda")
+pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
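Assembled, the new README recipe boils down to the following; a minimal sketch of the post-PR pattern (assumes a CUDA device and the fp16 revision used above):

```python
import torch
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# swap the default scheduler in place, reusing the existing config
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```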
@@ -10,6 +10,8 @@
- sections:
  - local: using-diffusers/loading
    title: "Loading Pipelines, Models, and Schedulers"
+ - local: using-diffusers/schedulers
+   title: "Using different Schedulers"
  - local: using-diffusers/configuration
    title: "Configuring Pipelines, Models, and Schedulers"
  - local: using-diffusers/custom_pipeline_overview
@@ -15,9 +15,9 @@ specific language governing permissions and limitations under the License.

In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`] and models of type [`ModelMixin`] inherit from [`ConfigMixin`], which conveniently takes care of storing all parameters that are
passed to the respective `__init__` methods in a JSON configuration file.

TODO(PVP) - add example and better info here

## ConfigMixin

[[autodoc]] ConfigMixin
    - load_config
    - from_config
    - save_config
@@ -39,7 +39,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler

# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
-scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
+scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")

# let's download an initial image
@@ -54,7 +54,7 @@ original_image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))

# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
-scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256")
+scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
pipe = pipe.to("cuda")
@@ -34,13 +34,17 @@ For more details about how Stable Diffusion works and how it differs from the ba

### How to load and use different schedulers.

The stable diffusion pipeline uses the [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline, such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], etc.
-To use a different scheduler, you can pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
+To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:

```python
->>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
+>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

->>> euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
->>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
+>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)

+>>> # or
+>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
```
@@ -41,7 +41,7 @@ In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generat

```python
>>> from diffusers import DiffusionPipeline

->>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
```

The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
@@ -49,13 +49,13 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm
You can move the generator object to GPU, just like you would in PyTorch.

```python
->>> generator.to("cuda")
+>>> pipeline.to("cuda")
```

-Now you can use the `generator` on your text prompt:
+Now you can use the `pipeline` on your text prompt:

```python
->>> image = generator("An image of a squirrel in Picasso style").images[0]
+>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
```

The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).
@@ -82,7 +82,7 @@ just like we did before only that now you need to pass your `AUTH_TOKEN`:

```python
>>> from diffusers import DiffusionPipeline

->>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
+>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
```

If you do not pass your authentication token you will see that the diffusion system will not be correctly
@@ -102,7 +102,7 @@ token. Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned
you can also load the pipeline as follows:

```python
->>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
+>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
```

Running the pipeline is then identical to the code above as it's the same model architecture.
@@ -115,19 +115,20 @@ Running the pipeline is then identical to the code above as it's the same model

Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their
pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
-use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler,
+use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`] scheduler,
you could use it as follows:

```python
->>> from diffusers import LMSDiscreteScheduler
+>>> from diffusers import EulerDiscreteScheduler

->>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
->>> generator = StableDiffusionPipeline.from_pretrained(
-...     "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
-... )
+>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)

+>>> # change scheduler to Euler
+>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
```

+For more in-detail information on how to change between schedulers, please refer to the [Using Schedulers](./using-diffusers/schedulers) guide.

[Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model
and can do much more than just generating images from text. We have dedicated a whole documentation page,
just for Stable Diffusion [here](./conceptual/stable_diffusion).
@@ -19,7 +19,7 @@ In the following we explain in-detail how to easily load:

- *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`]
- *Diffusion Models* via [`ModelMixin.from_pretrained`]
-- *Schedulers* via [`ConfigMixin.from_config`]
+- *Schedulers* via [`SchedulerMixin.from_pretrained`]

## Loading pipelines
@@ -137,15 +137,15 @@ from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultis

repo_id = "runwayml/stable-diffusion-v1-5"

-scheduler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
+scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
# or
-# scheduler = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
+# scheduler = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")

stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)
```

Three things are worth paying attention to here, as shown in the sketch after this list.
-- First, the scheduler is loaded with [`ConfigMixin.from_config`] since it only depends on a configuration file and not any parameterized weights
+- First, the scheduler is loaded with [`SchedulerMixin.from_pretrained`]
- Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"`, as the configuration of stable diffusion's scheduler is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler)
- Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler`, since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`]
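To make the first and third points concrete, a small sketch (the commented-out line illustrates the rejected keyword):

```python
from diffusers import DiffusionPipeline, EulerDiscreteScheduler

repo_id = "runwayml/stable-diffusion-v1-5"
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")

# accepted: `scheduler` matches the attribute name in StableDiffusionPipeline.__init__
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)

# not accepted: `sampler` is not a keyword of StableDiffusionPipeline.__init__
# stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, sampler=scheduler)
```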
@@ -337,8 +337,8 @@ model = UNet2DModel.from_pretrained(repo_id)

## Loading schedulers

-Schedulers cannot be loaded via a `from_pretrained` method, but instead rely on [`ConfigMixin.from_config`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
-Therefore the loading method was given a different name here.
+Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
+For consistency, we use the same method name as we do for models or pipelines, but no weights are loaded in this case.

In contrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers.
For example, all of:
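A small sketch of what "no weights are loaded" means in practice (the output directory name is illustrative):

```python
from diffusers import PNDMScheduler

scheduler = PNDMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

# nothing but a small JSON config lands on disk -- there are no weights to save
scheduler.save_pretrained("./sd15-scheduler")
```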
@@ -367,13 +367,13 @@ from diffusers import (

repo_id = "runwayml/stable-diffusion-v1-5"

-ddpm = DDPMScheduler.from_config(repo_id, subfolder="scheduler")
-ddim = DDIMScheduler.from_config(repo_id, subfolder="scheduler")
-pndm = PNDMScheduler.from_config(repo_id, subfolder="scheduler")
-lms = LMSDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
-euler_anc = EulerAncestralDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
-euler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
-dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
+ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")

# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc`
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
docs/source/using-diffusers/schedulers.mdx (new file, 262 lines)
@@ -0,0 +1,262 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Schedulers

Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize
a pipeline to one's use case. The best example of this is the [Schedulers](../api/schedulers.mdx).

Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample,
schedulers define the whole denoising process, *i.e.*:
- How many denoising steps?
- Stochastic or deterministic?
- What algorithm to use to find the denoised sample

They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**.
It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try out which works best.

The following paragraphs show how to do so with the 🧨 Diffusers library.
## Load pipeline

Let's start by loading the stable diffusion pipeline.
Remember that you have to be a registered user on the 🤗 Hugging Face Hub, and have "click-accepted" the [license](https://huggingface.co/runwayml/stable-diffusion-v1-5) in order to use stable diffusion.

```python
from huggingface_hub import login
from diffusers import DiffusionPipeline
import torch

# first we need to login with our access token
login()

# Now we can download the pipeline
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```

Next, we move it to GPU:

```python
pipeline.to("cuda")
```
## Access the scheduler

The scheduler is always one of the components of the pipeline and is usually called `"scheduler"`.
So it can be accessed via the `"scheduler"` property.

```python
pipeline.scheduler
```

**Output**:
```
PNDMScheduler {
  "_class_name": "PNDMScheduler",
  "_diffusers_version": "0.8.0.dev0",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "num_train_timesteps": 1000,
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "trained_betas": null
}
```

We can see that the scheduler is of type [`PNDMScheduler`].
Cool, now let's compare its performance to that of other schedulers.
First we define a prompt on which we will test all the different schedulers:

```python
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
```
Next, we create a generator from a fixed random seed, which ensures we can reproduce similar images, and run the pipeline:

```python
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```

<p align="center">
    <br>
    <img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_pndm.png" width="400"/>
    <br>
</p>
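Before switching schedulers, note that the step count the introduction's first bullet refers to is a call-time knob; a small sketch reusing the `pipeline`, `prompt`, and seed from above:

```python
# fewer steps -> faster, usually rougher; more steps -> slower, usually cleaner
generator = torch.Generator(device="cuda").manual_seed(8)
image_fast = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]

generator = torch.Generator(device="cuda").manual_seed(8)
image_slow = pipeline(prompt, generator=generator, num_inference_steps=100).images[0]
```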
## Changing the scheduler

Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`SchedulerMixin.compatibles`]
which defines all compatible schedulers. You can take a look at all available, compatible schedulers for the Stable Diffusion pipeline as follows.

```python
pipeline.scheduler.compatibles
```

**Output**:
```
[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
 diffusers.schedulers.scheduling_ddim.DDIMScheduler,
 diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
 diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
 diffusers.schedulers.scheduling_pndm.PNDMScheduler,
 diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
 diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]
```

Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions:

- [`LMSDiscreteScheduler`],
- [`DDIMScheduler`],
- [`DPMSolverMultistepScheduler`],
- [`EulerDiscreteScheduler`],
- [`PNDMScheduler`],
- [`DDPMScheduler`],
- [`EulerAncestralDiscreteScheduler`].
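Since every entry of `compatibles` accepts the same config, they can all be instantiated in a loop; a minimal sketch reusing the `pipeline` from above:

```python
# instantiate each compatible scheduler from the current scheduler's config
for scheduler_class in pipeline.scheduler.compatibles:
    alt_scheduler = scheduler_class.from_config(pipeline.scheduler.config)
    print(alt_scheduler.__class__.__name__)
```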
We will now compare the input prompt with all other schedulers. To change the scheduler of the pipeline you can make use of the
convenient [`ConfigMixin.config`] property in combination with the [`ConfigMixin.from_config`] function.

```python
pipeline.scheduler.config
```

returns a dictionary of the configuration of the scheduler:

**Output**:
```
FrozenDict([('num_train_timesteps', 1000),
            ('beta_start', 0.00085),
            ('beta_end', 0.012),
            ('beta_schedule', 'scaled_linear'),
            ('trained_betas', None),
            ('skip_prk_steps', True),
            ('set_alpha_to_one', False),
            ('steps_offset', 1),
            ('_class_name', 'PNDMScheduler'),
            ('_diffusers_version', '0.8.0.dev0'),
            ('clip_sample', False)])
```

This configuration can then be used to instantiate a scheduler
of a different class that is compatible with the pipeline. Here,
we change the scheduler to the [`DDIMScheduler`].
```python
from diffusers import DDIMScheduler

pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
```

Cool, now we can run the pipeline again to compare the generation quality.

```python
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```

<p align="center">
    <br>
    <img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_ddim.png" width="400"/>
    <br>
</p>
## Compare schedulers

So far we have tried running the stable diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`].
A number of better schedulers have been released that can be run with far fewer steps; let's compare them here:

[`LMSDiscreteScheduler`] usually leads to better results:

```python
from diffusers import LMSDiscreteScheduler

pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```

<p align="center">
    <br>
    <img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png" width="400"/>
    <br>
</p>
[`EulerDiscreteScheduler`] and [`EulerAncestralDiscreteScheduler`] can generate high quality results with as few as 30 steps.

```python
from diffusers import EulerDiscreteScheduler

pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
```

<p align="center">
    <br>
    <img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png" width="400"/>
    <br>
</p>
and:

```python
from diffusers import EulerAncestralDiscreteScheduler

pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
```

<p align="center">
    <br>
    <img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png" width="400"/>
    <br>
</p>
At the time of writing this doc, [`DPMSolverMultistepScheduler`] gives arguably the best speed/quality trade-off and can be run with as few
as 20 steps.

```python
from diffusers import DPMSolverMultistepScheduler

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)

generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
image
```

<p align="center">
    <br>
    <img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png" width="400"/>
    <br>
</p>

As you can see, most images look very similar and are arguably of very similar quality. Which scheduler to choose often depends on the specific use case. A good approach is always to run multiple different
schedulers to compare results.
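A compact way to follow that advice is to sweep the schedulers compared above in one loop; a sketch reusing `pipeline` and `prompt` (file names and step counts are illustrative, following the text above):

```python
import torch
from diffusers import (
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
)

for scheduler_class, steps in [
    (LMSDiscreteScheduler, 50),
    (EulerDiscreteScheduler, 30),
    (DPMSolverMultistepScheduler, 20),
]:
    pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config)
    generator = torch.Generator(device="cuda").manual_seed(8)
    image = pipeline(prompt, generator=generator, num_inference_steps=steps).images[0]
    image.save(f"astronaut_{scheduler_class.__name__}.png")
```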
@@ -29,7 +29,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError

from . import __version__
-from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
+from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging


logger = logging.get_logger(__name__)
@@ -37,6 +37,38 @@ logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json")


+class FrozenDict(OrderedDict):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        for key, value in self.items():
+            setattr(self, key, value)
+
+        self.__frozen = True
+
+    def __delitem__(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+    def setdefault(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+    def pop(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+    def update(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+    def __setattr__(self, name, value):
+        if hasattr(self, "__frozen") and self.__frozen:
+            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+        super().__setattr__(name, value)
+
+    def __setitem__(self, name, value):
+        if hasattr(self, "__frozen") and self.__frozen:
+            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+        super().__setitem__(name, value)


class ConfigMixin:
    r"""
    Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles all
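A quick sketch of the behavior the class above enforces (values are illustrative):

```python
from diffusers.configuration_utils import FrozenDict

config = FrozenDict([("num_train_timesteps", 1000), ("beta_start", 0.00085)])

config["num_train_timesteps"]  # 1000, normal dict read access
config.num_train_timesteps     # 1000, attribute access set up in __init__

# any mutation after construction raises, e.g.:
# config["beta_start"] = 0.001   -> Exception from __setitem__
# config.pop("beta_start")       -> Exception from pop
```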
@@ -49,13 +81,12 @@ class ConfigMixin:
        [`~ConfigMixin.save_config`] (should be overridden by parent class).
    - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
        overridden by parent class).
-   - **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
-       `from_config` can be used from a class different than the one used to save the config (should be overridden
-       by parent class).
+   - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
+       class).
    """
    config_name = None
    ignore_for_config = []
-   _compatible_classes = []
+   has_compatibles = False

    def register_to_config(self, **kwargs):
        if self.config_name is None:
@@ -104,9 +135,98 @@ class ConfigMixin:
        logger.info(f"Configuration saved in {output_config_file}")

    @classmethod
-   def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+   def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        r"""
-       Instantiate a Python class from a pre-defined JSON-file.
+       Instantiate a Python class from a config dictionary

        Parameters:
            config (`Dict[str, Any]`):
                A config dictionary from which the Python class will be instantiated. Make sure to only load
                configuration files of compatible classes.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the Python class.
                `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
                overwrite same named arguments of `config`.

        Examples:

        ```python
        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

        >>> # Download scheduler from huggingface.co and cache.
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")

        >>> # Instantiate DDIM scheduler class with same config as DDPM
        >>> scheduler = DDIMScheduler.from_config(scheduler.config)

        >>> # Instantiate PNDM scheduler class with same config as DDPM
        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
        ```
        """
        # <===== TO BE REMOVED WITH DEPRECATION
        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
        if "pretrained_model_name_or_path" in kwargs:
            config = kwargs.pop("pretrained_model_name_or_path")

        if config is None:
            raise ValueError("Please make sure to provide a config as the first positional argument.")
        # ======>

        if not isinstance(config, dict):
            deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
            if "Scheduler" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
                    " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
                    " be removed in v1.0.0."
                )
            elif "Model" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
                    f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
                    " instead. This functionality will be removed in v1.0.0."
                )
            deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
            config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)

        # Allow dtype to be specified on initialization
        if "dtype" in unused_kwargs:
            init_dict["dtype"] = unused_kwargs.pop("dtype")

        # Return model and optionally state and/or unused_kwargs
        model = cls(**init_dict)

        # make sure to also save config parameters that might be used for compatible classes
        model.register_to_config(**hidden_dict)

        # add hidden kwargs of compatible classes to unused_kwargs
        unused_kwargs = {**unused_kwargs, **hidden_dict}

        if return_unused_kwargs:
            return (model, unused_kwargs)
        else:
            return model

    @classmethod
    def get_config_dict(cls, *args, **kwargs):
        deprecation_message = (
            f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
            " removed in version v1.0.0"
        )
        deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
        return cls.load_config(*args, **kwargs)

    @classmethod
    def load_config(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        r"""
        Instantiate a Python class from a config dictionary

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
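Putting the deprecation branch and the docstring example together, a sketch of the three call styles `from_config` now distinguishes:

```python
from diffusers import DDIMScheduler

# preferred for loading from the Hub or disk: from_pretrained
ddim = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

# preferred for class switching: from_config with a config dict
ddim_again = DDIMScheduler.from_config(ddim.config)

# deprecated but still working: a model path passed to from_config takes the
# `config-passed-as-path` branch above and emits a deprecation warning
legacy = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
```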
@@ -120,10 +240,6 @@ class ConfigMixin:
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
-           ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
-               Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
-               as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
-               checkpoint with 3 labels).
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
@@ -161,33 +277,7 @@ class ConfigMixin:
        use this method in a firewalled environment.

        </Tip>

        """
-       config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
-       init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
-
-       # Allow dtype to be specified on initialization
-       if "dtype" in unused_kwargs:
-           init_dict["dtype"] = unused_kwargs.pop("dtype")
-
-       # Return model and optionally state and/or unused_kwargs
-       model = cls(**init_dict)
-       return_tuple = (model,)
-
-       # Flax schedulers have a state, so return it.
-       if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
-           state = model.create_state()
-           return_tuple += (state,)
-
-       if return_unused_kwargs:
-           return return_tuple + (unused_kwargs,)
-       else:
-           return return_tuple if len(return_tuple) > 1 else model
-
-   @classmethod
-   def get_config_dict(
-       cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
-   ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
@@ -283,6 +373,9 @@ class ConfigMixin:
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

+       if return_unused_kwargs:
+           return config_dict, kwargs
+
        return config_dict

    @staticmethod
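A small sketch of the new `return_unused_kwargs` branch (the extra keyword is hypothetical and only there to show what comes back unconsumed):

```python
from diffusers import DDIMScheduler

config, unused = DDIMScheduler.load_config(
    "runwayml/stable-diffusion-v1-5",
    subfolder="scheduler",
    return_unused_kwargs=True,
    my_extra_flag=True,  # not consumed by load_config, so it returns in `unused`
)
```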
@@ -291,6 +384,9 @@ class ConfigMixin:

    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
+       # 0. Copy origin config dict
+       original_dict = {k: v for k, v in config_dict.items()}
+
        # 1. Retrieve expected config attributes from __init__ signature
        expected_keys = cls._get_init_keys(cls)
        expected_keys.remove("self")
@@ -310,10 +406,11 @@ class ConfigMixin:
        # load diffusers library to import compatible and original scheduler
        diffusers_library = importlib.import_module(__name__.split(".")[0])

-       # remove attributes from compatible classes that orig cannot expect
-       compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
-       # filter out None potentially undefined dummy classes
-       compatible_classes = [c for c in compatible_classes if c is not None]
+       if cls.has_compatibles:
+           compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
+       else:
+           compatible_classes = []

        expected_keys_comp_cls = set()
        for c in compatible_classes:
            expected_keys_c = cls._get_init_keys(c)
@@ -364,7 +461,10 @@ class ConfigMixin:
        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

-       return init_dict, unused_kwargs
+       # 7. Define "hidden" config parameters that were saved for compatible classes
+       hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")}
+
+       return init_dict, unused_kwargs, hidden_config_dict

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
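The practical effect of the new hidden dict: parameters that only a compatible class understands survive a round trip through another scheduler. A sketch:

```python
from diffusers import DDIMScheduler, PNDMScheduler

pndm = PNDMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

# `skip_prk_steps` is PNDM-specific; DDIM's __init__ does not accept it,
# so it lands in the hidden dict and is re-registered on the new instance
ddim = DDIMScheduler.from_config(pndm.config)

# converting back therefore preserves the original value
pndm_again = PNDMScheduler.from_config(ddim.config)
assert pndm_again.config.skip_prk_steps == pndm.config.skip_prk_steps
```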
@@ -377,6 +477,12 @@ class ConfigMixin:

    @property
    def config(self) -> Dict[str, Any]:
+       """
+       Returns the config of the class as a frozen dictionary
+
+       Returns:
+           `Dict[str, Any]`: Config of the class.
+       """
        return self._internal_dict

    def to_json_string(self) -> str:
@@ -401,38 +507,6 @@ class ConfigMixin:
        writer.write(self.to_json_string())


-class FrozenDict(OrderedDict):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        for key, value in self.items():
-            setattr(self, key, value)
-
-        self.__frozen = True
-
-    def __delitem__(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
-
-    def setdefault(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
-
-    def pop(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
-
-    def update(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
-
-    def __setattr__(self, name, value):
-        if hasattr(self, "__frozen") and self.__frozen:
-            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
-        super().__setattr__(name, value)
-
-    def __setitem__(self, name, value):
-        if hasattr(self, "__frozen") and self.__frozen:
-            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
-        super().__setitem__(name, value)


def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
    "diffusers": {
        "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
-       "FlaxSchedulerMixin": ["save_config", "from_config"],
+       "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
        "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
        >>> from diffusers import FlaxDPMSolverMultistepScheduler

        >>> model_id = "runwayml/stable-diffusion-v1-5"
-       >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_config(
+       >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
        ...     model_id,
        ...     subfolder="scheduler",
        ... )
@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
        if not os.path.isdir(pretrained_model_name_or_path):
-           config_dict = cls.get_config_dict(
+           config_dict = cls.load_config(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
        else:
            cached_folder = pretrained_model_name_or_path

-       config_dict = cls.get_config_dict(cached_folder)
+       config_dict = cls.load_config(cached_folder)

        # 2. Load the pipeline class, if using custom module then load it from the hub
        # if we load from explicit class, let's use it
@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
        expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

-       init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+       init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        init_kwargs = {}
@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
-       "SchedulerMixin": ["save_config", "from_config"],
+       "SchedulerMixin": ["save_pretrained", "from_pretrained"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
    },
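For orientation, a sketch of the dispatch this table drives (simplified; `resolve_save_method` is a hypothetical helper, not library code): for each pipeline component, the first base class found in the table decides which save method is called — schedulers now go through `save_pretrained` like everything else.

```python
import importlib

def resolve_save_method(component):
    # hypothetical helper mirroring how LOADABLE_CLASSES is consumed
    for library_name, candidates in LOADABLE_CLASSES.items():
        library = importlib.import_module(library_name)
        for class_name, (save_method, _load_method) in candidates.items():
            base_class = getattr(library, class_name, None)
            if base_class is not None and isinstance(component, base_class):
                return getattr(component, save_method)
    return None
```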
@@ -207,7 +207,7 @@ class DiffusionPipeline(ConfigMixin):
        if torch_device is None:
            return self

-       module_names, _ = self.extract_init_dict(dict(self.config))
+       module_names, _, _ = self.extract_init_dict(dict(self.config))
        for name in module_names.keys():
            module = getattr(self, name)
            if isinstance(module, torch.nn.Module):
@@ -228,7 +228,7 @@ class DiffusionPipeline(ConfigMixin):
        Returns:
            `torch.device`: The torch device on which the pipeline is located.
        """
-       module_names, _ = self.extract_init_dict(dict(self.config))
+       module_names, _, _ = self.extract_init_dict(dict(self.config))
        for name in module_names.keys():
            module = getattr(self, name)
            if isinstance(module, torch.nn.Module):
@@ -377,11 +377,11 @@ class DiffusionPipeline(ConfigMixin):
        >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
        >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

-       >>> # Download pipeline, but overwrite scheduler
+       >>> # Use a different scheduler
        >>> from diffusers import LMSDiscreteScheduler

-       >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
-       >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
+       >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+       >>> pipeline.scheduler = scheduler
        ```
        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
@@ -428,7 +428,7 @@ class DiffusionPipeline(ConfigMixin):
        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
        if not os.path.isdir(pretrained_model_name_or_path):
-           config_dict = cls.get_config_dict(
+           config_dict = cls.load_config(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
@@ -474,7 +474,7 @@ class DiffusionPipeline(ConfigMixin):
        else:
            cached_folder = pretrained_model_name_or_path

-       config_dict = cls.get_config_dict(cached_folder)
+       config_dict = cls.load_config(cached_folder)

        # 2. Load the pipeline class, if using custom module then load it from the hub
        # if we load from explicit class, let's use it
@@ -513,7 +513,7 @@ class DiffusionPipeline(ConfigMixin):
        expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

-       init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
+       init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        if len(unused_kwargs) > 0:
            logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
@@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline):
        """
        message = (
            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
-           " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+           " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, DDIMScheduler

-scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler

-lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler

# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
-scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
+scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")

# let's download an initial image
@@ -23,7 +23,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin
@@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-   [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-   [`~ConfigMixin.from_config`] functions.
+   [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+   [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502
@@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

    """

-   _compatible_classes = [
-       "PNDMScheduler",
-       "DDPMScheduler",
-       "LMSDiscreteScheduler",
-       "EulerDiscreteScheduler",
-       "EulerAncestralDiscreteScheduler",
-       "DPMSolverMultistepScheduler",
-   ]
+   _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

    @register_to_config
    def __init__(
@@ -23,7 +23,12 @@ import flax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    FlaxSchedulerMixin,
+    FlaxSchedulerOutput,
+    broadcast_to_shape_from_left,
+)


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-   [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-   [`~ConfigMixin.from_config`] functions.
+   [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+   [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502
@@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
        stable diffusion.
    """

+   _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
    @property
    def has_state(self):
        return True
@@ -22,7 +22,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
-from ..utils import BaseOutput, deprecate
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
@@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-   [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-   [`~ConfigMixin.from_config`] functions.
+   [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+   [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239
@@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

    """

-   _compatible_classes = [
-       "DDIMScheduler",
-       "PNDMScheduler",
-       "LMSDiscreteScheduler",
-       "EulerDiscreteScheduler",
-       "EulerAncestralDiscreteScheduler",
-       "DPMSolverMultistepScheduler",
-   ]
+   _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

    @register_to_config
    def __init__(
@@ -249,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        """
        message = (
            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
-           " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+           " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
@@ -24,7 +24,12 @@ from jax import random

from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import deprecate
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    FlaxSchedulerMixin,
+    FlaxSchedulerOutput,
+    broadcast_to_shape_from_left,
+)


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-   [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-   [`~ConfigMixin.from_config`] functions.
+   [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+   [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239
@@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):

    """

+   _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
    @property
    def has_state(self):
        return True
@@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
        """
        message = (
            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
-           " DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
+           " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
@@ -21,6 +21,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-   [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-   [`~ConfigMixin.from_config`] functions.
+   [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+   [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

    """

-   _compatible_classes = [
-       "DDIMScheduler",
-       "DDPMScheduler",
-       "PNDMScheduler",
-       "LMSDiscreteScheduler",
-       "EulerDiscreteScheduler",
-       "EulerAncestralDiscreteScheduler",
-   ]
+   _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

    @register_to_config
    def __init__(
@@ -23,7 +23,12 @@ import jax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    FlaxSchedulerMixin,
+    FlaxSchedulerOutput,
+    broadcast_to_shape_from_left,
+)


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
-   [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
-   [`~ConfigMixin.from_config`] functions.
+   [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+   [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
@@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):

    """

+   _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
    @property
    def has_state(self):
        return True
@@ -19,7 +19,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
from .scheduling_utils import SchedulerMixin
@@ -52,8 +52,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
@@ -67,14 +67,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"PNDMScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -19,7 +19,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
from .scheduling_utils import SchedulerMixin


@@ -53,8 +53,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -68,14 +68,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):

"""

_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"PNDMScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

@register_to_config
def __init__(

@@ -28,8 +28,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

For more details, see the original paper: https://arxiv.org/abs/2202.09778


@@ -56,8 +56,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the

@@ -67,8 +67,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the

@@ -21,7 +21,7 @@ import torch
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin


@@ -52,8 +52,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -67,14 +67,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):

"""

_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

@register_to_config
def __init__(

@@ -20,7 +20,12 @@ import jax.numpy as jnp
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)


@flax.struct.dataclass
@@ -49,8 +54,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -63,6 +68,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

@property
def has_state(self):
return True

@@ -21,6 +21,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput


@@ -60,8 +61,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

For more details, see the original paper: https://arxiv.org/abs/2202.09778

@@ -88,14 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

"""

_compatible_classes = [
"DDIMScheduler",
"DDPMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

@register_to_config
def __init__(

@@ -23,7 +23,12 @@ import jax
import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -87,8 +92,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

For more details, see the original paper: https://arxiv.org/abs/2202.09778

@@ -114,6 +119,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
stable diffusion.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()

@property
def has_state(self):
return True

@@ -77,8 +77,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf


@@ -50,8 +50,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.

@@ -64,8 +64,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.

@@ -29,8 +29,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

For more information, see the original paper: https://arxiv.org/abs/2011.13456


@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import torch

@@ -38,6 +41,114 @@ class SchedulerOutput(BaseOutput):
class SchedulerMixin:
"""
Mixin containing common functions for the schedulers.

Class attributes:
- **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
`from_config` can be used from a class different than the one used to save the config (should be overridden
by parent class).
"""

config_name = SCHEDULER_CONFIG_NAME
_compatibles = []
has_compatibles = True

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Dict[str, Any] = None,
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
):
r"""
Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.

Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:

- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing the scheduler configurations saved using
[`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
subfolder (`str`, *optional*):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.

<Tip>

It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).

</Tip>

<Tip>

Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.

</Tip>

"""
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
return_unused_kwargs=True,
**kwargs,
)
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~SchedulerMixin.from_pretrained`] class method.

Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
"""
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

@property
def compatibles(self):
"""
Returns all schedulers that are compatible with this scheduler

Returns:
`List[SchedulerMixin]`: List of compatible schedulers
"""
return self._get_compatibles()

@classmethod
def _get_compatibles(cls):
compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
diffusers_library = importlib.import_module(__name__.split(".")[0])
compatible_classes = [
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
]
return compatible_classes

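Taken together, `SchedulerMixin` now gives every scheduler a `from_pretrained`/`save_pretrained` round trip plus a `compatibles` query. A minimal sketch of the resulting workflow (repo id as used in the docs and tests of this PR):

```python
import tempfile

from diffusers import DDIMScheduler

# Load the scheduler configuration from the "scheduler" subfolder of a Hub repo
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

# Query which scheduler classes can be swapped in for this one
print([cls.__name__ for cls in scheduler.compatibles])

# Round-trip the JSON configuration through a local directory
with tempfile.TemporaryDirectory() as tmpdirname:
    scheduler.save_pretrained(tmpdirname)
    scheduler = DDIMScheduler.from_pretrained(tmpdirname)
```
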
@@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from dataclasses import dataclass
from typing import Tuple
from typing import Any, Dict, Optional, Tuple, Union

import jax.numpy as jnp

from ..utils import BaseOutput
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput


SCHEDULER_CONFIG_NAME = "scheduler_config.json"
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS]


@dataclass
@@ -39,9 +42,123 @@ class FlaxSchedulerOutput(BaseOutput):
class FlaxSchedulerMixin:
"""
Mixin containing common functions for the schedulers.

Class attributes:
- **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
`from_config` can be used from a class different than the one used to save the config (should be overridden
by parent class).
"""

config_name = SCHEDULER_CONFIG_NAME
_compatibles = []
has_compatibles = True

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Dict[str, Any] = None,
subfolder: Optional[str] = None,
return_unused_kwargs=False,
**kwargs,
):
r"""
Instantiate a Scheduler class from a pre-defined JSON-file.

Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:

- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`],
e.g., `./my_model_directory/`.
subfolder (`str`, *optional*):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not.

cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.

<Tip>

It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).

</Tip>

<Tip>

Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.

</Tip>

"""
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)

if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
state = scheduler.create_state()

if return_unused_kwargs:
return scheduler, state, unused_kwargs

return scheduler, state

def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~FlaxSchedulerMixin.from_pretrained`] class method.

Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
"""
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

@property
def compatibles(self):
"""
Returns all schedulers that are compatible with this scheduler

Returns:
`List[SchedulerMixin]`: List of compatible schedulers
"""
return self._get_compatibles()

@classmethod
def _get_compatibles(cls):
compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
diffusers_library = importlib.import_module(__name__.split(".")[0])
compatible_classes = [
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
]
return compatible_classes


def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:

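Note that the Flax mixin returns a `(scheduler, state)` pair: when the class defines `create_state` and `has_state` is `True`, the freshly created state is handed back alongside the scheduler. A hedged sketch, assuming a local directory previously written by `save_pretrained` (the `./my_scheduler_directory` path is illustrative only):

```python
from diffusers import FlaxPNDMScheduler

# `has_state` is True for FlaxPNDMScheduler, so `from_pretrained` also calls
# `create_state()` and returns the pair.
scheduler, state = FlaxPNDMScheduler.from_pretrained("./my_scheduler_directory")
```
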
@@ -112,8 +112,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

For more details, see the original paper: https://arxiv.org/abs/2111.14822


@@ -72,3 +72,13 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))

_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]

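This single shared constant replaces the per-class `_compatible_classes` lists, so compatibility is declared once and stays symmetric across all Stable Diffusion schedulers. A small sketch of what that enables (pipeline checkpoint as used elsewhere in this PR):

```python
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Every class in the shared list is exposed via `compatibles` and can be
# swapped in by re-instantiating from the current scheduler's config.
assert DPMSolverMultistepScheduler in pipe.scheduler.compatibles
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```
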
@@ -67,8 +67,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
super().test_from_pretrained_save_pretrained()

@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_config(self):
super().test_model_from_config()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()

@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self):
@@ -187,8 +187,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
super().test_from_pretrained_save_pretrained()

@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_config(self):
super().test_model_from_config()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()

@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self):

@@ -75,7 +75,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
model_id = "google/ddpm-ema-bedroom-256"

unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDIMScheduler.from_config(model_id)
scheduler = DDIMScheduler.from_pretrained(model_id)

ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)

@@ -106,7 +106,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
model_id = "google/ddpm-cifar10-32"

unet = UNet2DModel.from_pretrained(model_id)
scheduler = DDPMScheduler.from_config(model_id)
scheduler = DDPMScheduler.from_pretrained(model_id)

ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)

@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):

model_id = "google/ddpm-ema-celebahq-256"
unet = UNet2DModel.from_pretrained(model_id)
scheduler = RePaintScheduler.from_config(model_id)
scheduler = RePaintScheduler.from_pretrained(model_id)

repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device)


@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
model_id = "google/ncsnpp-church-256"
model = UNet2DModel.from_pretrained(model_id)

scheduler = ScoreSdeVeScheduler.from_config(model_id)
scheduler = ScoreSdeVeScheduler.from_pretrained(model_id)

sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
sde_ve.to(torch_device)

@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((512, 512))

model_id = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(
model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16"
)
@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((512, 512))

model_id = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)

pipe.to(torch_device)

@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

def test_inference_ddim(self):
ddim_scheduler = DDIMScheduler.from_config(
ddim_scheduler = DDIMScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
)
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

def test_inference_k_lms(self):
lms_scheduler = LMSDiscreteScheduler.from_config(
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
)
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(

@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
lms_scheduler = LMSDiscreteScheduler.from_config(
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
)
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(

@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
lms_scheduler = LMSDiscreteScheduler.from_config(
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx"
)
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(

@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

def test_stable_diffusion_fast_ddim(self):
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-1", subfolder="scheduler")

sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
sd_pipe = sd_pipe.to(torch_device)
@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
pipe.set_progress_bar_config(disable=None)
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe.scheduler = scheduler

prompt = "a photograph of an astronaut riding a horse"

@@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)

model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
scheduler=lms,
@@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)

model_id = "CompVis/stable-diffusion-v1-4"
ddim = DDIMScheduler.from_config(model_id, subfolder="scheduler")
ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
scheduler=ddim,
@@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((768, 512))

model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
)

@@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)

model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)

model_id = "runwayml/stable-diffusion-inpainting"
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
safety_checker=None,

@@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
)

model_id = "CompVis/stable-diffusion-v1-4"
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
scheduler=lms,

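All of the integration-test updates above exercise the same migration: scheduler checkpoints are loaded with `from_pretrained` instead of the old path-accepting `from_config`. The underlying pattern, sketched with one of the checkpoints used above:

```python
from diffusers import DDIMScheduler, StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"

# Load the scheduler first, then hand it to the pipeline
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler)
```
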
@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
import unittest

import diffusers
from diffusers import (
DDIMScheduler,
DDPMScheduler,
@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin):
class ConfigTester(unittest.TestCase):
def test_load_not_from_mixin(self):
with self.assertRaises(ValueError):
ConfigMixin.from_config("dummy_path")
ConfigMixin.load_config("dummy_path")

def test_register_to_config(self):
obj = SampleObject()
@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase):

with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
new_obj = SampleObject.from_config(tmpdirname)
new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname))
new_config = new_obj.config

# unfreeze configs
@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase):
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
assert config == new_config

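As the updated tests show, `from_config` no longer accepts a path; loading a saved configuration is now an explicit two-step operation. A minimal sketch, assuming a directory previously written by `save_config` (the `./my_scheduler_directory` path is illustrative only):

```python
from diffusers import DDIMScheduler

# Step 1: read the JSON file into a config dict
config = DDIMScheduler.load_config("./my_scheduler_directory")

# Step 2: instantiate the scheduler from that config
scheduler = DDIMScheduler.from_config(config)
```
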
def test_save_load_from_different_config(self):
obj = SampleObject()

# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
logger = logging.get_logger("diffusers.configuration_utils")

with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
with CaptureLogger(logger) as cap_logger_1:
new_obj_1 = SampleObject2.from_config(tmpdirname)

# now save a config parameter that is not expected
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
data = json.load(f)
data["unexpected"] = True

with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
json.dump(data, f)

with CaptureLogger(logger) as cap_logger_2:
new_obj_2 = SampleObject.from_config(tmpdirname)

with CaptureLogger(logger) as cap_logger_3:
new_obj_3 = SampleObject2.from_config(tmpdirname)

assert new_obj_1.__class__ == SampleObject2
assert new_obj_2.__class__ == SampleObject
assert new_obj_3.__class__ == SampleObject2

assert cap_logger_1.out == ""
assert (
cap_logger_2.out
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.\n"
)
assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out

def test_save_load_compatible_schedulers(self):
SampleObject2._compatible_classes = ["SampleObject"]
SampleObject._compatible_classes = ["SampleObject2"]

obj = SampleObject()

# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
setattr(diffusers, "SampleObject2", SampleObject2)
logger = logging.get_logger("diffusers.configuration_utils")

with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)

# now save a config parameter that is expected by another class, but not origin class
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
data = json.load(f)
data["f"] = [0, 0]
data["unexpected"] = True

with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
json.dump(data, f)

with CaptureLogger(logger) as cap_logger:
new_obj = SampleObject.from_config(tmpdirname)

assert new_obj.__class__ == SampleObject

assert (
cap_logger.out
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
" be ignored. Please verify your config.json configuration file.\n"
)

def test_save_load_from_different_config_comp_schedulers(self):
SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"]
SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"]
SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"]

obj = SampleObject()

# mock add obj class to `diffusers`
setattr(diffusers, "SampleObject", SampleObject)
setattr(diffusers, "SampleObject2", SampleObject2)
setattr(diffusers, "SampleObject3", SampleObject3)
logger = logging.get_logger("diffusers.configuration_utils")
logger.setLevel(diffusers.logging.INFO)

with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)

with CaptureLogger(logger) as cap_logger_1:
new_obj_1 = SampleObject.from_config(tmpdirname)

with CaptureLogger(logger) as cap_logger_2:
new_obj_2 = SampleObject2.from_config(tmpdirname)

with CaptureLogger(logger) as cap_logger_3:
new_obj_3 = SampleObject3.from_config(tmpdirname)

assert new_obj_1.__class__ == SampleObject
assert new_obj_2.__class__ == SampleObject2
assert new_obj_3.__class__ == SampleObject3

assert cap_logger_1.out == ""
assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"

def test_load_ddim_from_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")

with CaptureLogger(logger) as cap_logger:
ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
ddim = DDIMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)

assert ddim.__class__ == DDIMScheduler
# no warning should be thrown
@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")

with CaptureLogger(logger) as cap_logger:
euler = EulerDiscreteScheduler.from_config(
euler = EulerDiscreteScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)

@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")

with CaptureLogger(logger) as cap_logger:
euler = EulerAncestralDiscreteScheduler.from_config(
euler = EulerAncestralDiscreteScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)

@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")

with CaptureLogger(logger) as cap_logger:
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)

assert pndm.__class__ == PNDMScheduler
# no warning should be thrown
@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")

with CaptureLogger(logger) as cap_logger:
ddpm = DDPMScheduler.from_config(
ddpm = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler",
predict_epsilon=False,
@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase):
)

with CaptureLogger(logger) as cap_logger_2:
ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88)
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)

assert ddpm.__class__ == DDPMScheduler
assert ddpm.config.predict_epsilon is False
@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")

with CaptureLogger(logger) as cap_logger:
dpm = DPMSolverMultistepScheduler.from_config(
dpm = DPMSolverMultistepScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)


@@ -130,7 +130,7 @@ class ModelTesterMixin:
expected_arg_names = ["sample", "timestep"]
self.assertListEqual(arg_names[:2], expected_arg_names)

def test_model_from_config(self):
def test_model_from_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**init_dict)
@@ -140,8 +140,8 @@ class ModelTesterMixin:
# test if the model can be loaded from the config
# and has all the expected shape
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_config(tmpdirname)
new_model = self.model_class.from_config(tmpdirname)
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
new_model.eval()


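The renamed test now exercises the full weights-plus-config round trip rather than the config-only one. Roughly, for any model checkpoint (the DDPM checkpoint below is the one used in the pipeline tests above):

```python
import tempfile

from diffusers import UNet2DModel

model = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)  # saves weights and config together
    new_model = UNet2DModel.from_pretrained(tmpdirname)
```
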
@@ -29,6 +29,10 @@ from diffusers import (
DDIMScheduler,
DDPMPipeline,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipelineLegacy,
@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
assert image_img2img.shape == (1, 32, 32, 3)
assert image_text2img.shape == (1, 128, 128, 3)

def test_set_scheduler(self):
unet = self.dummy_cond_unet
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

sd = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)

sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, DDIMScheduler)
sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, DDPMScheduler)
sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, PNDMScheduler)
sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, LMSDiscreteScheduler)
sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, EulerDiscreteScheduler)
sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler)
sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)

def test_set_scheduler_consistency(self):
unet = self.dummy_cond_unet
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

sd = StableDiffusionPipeline(
unet=unet,
scheduler=pndm,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)

pndm_config = sd.scheduler.config
sd.scheduler = DDPMScheduler.from_config(pndm_config)
sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
pndm_config_2 = sd.scheduler.config
pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config}

assert dict(pndm_config) == dict(pndm_config_2)

sd = StableDiffusionPipeline(
unet=unet,
scheduler=ddim,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)

ddim_config = sd.scheduler.config
sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config)
sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
ddim_config_2 = sd.scheduler.config
ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config}

assert dict(ddim_config) == dict(ddim_config_2)


@slow
class PipelineSlowTests(unittest.TestCase):
@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_output_format(self):
model_path = "google/ddpm-cifar10-32"

scheduler = DDIMScheduler.from_config(model_path)
scheduler = DDIMScheduler.from_pretrained(model_path)
pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
@@ -21,6 +23,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
@@ -32,13 +35,180 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
ScoreSdeVeScheduler,
|
||||
VQDiffusionScheduler,
|
||||
logging,
|
||||
)
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from diffusers.utils import deprecate, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class SchedulerObject(SchedulerMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
a=2,
|
||||
b=5,
|
||||
c=(2, 5),
|
||||
d="for diffusion",
|
||||
e=[1, 3],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class SchedulerObject2(SchedulerMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
a=2,
|
||||
b=5,
|
||||
c=(2, 5),
|
||||
d="for diffusion",
|
||||
f=[1, 3],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class SchedulerObject3(SchedulerMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
a=2,
|
||||
b=5,
|
||||
c=(2, 5),
|
||||
d="for diffusion",
|
||||
e=[1, 3],
|
||||
f=[1, 3],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class SchedulerBaseTests(unittest.TestCase):
|
||||
def test_save_load_from_different_config(self):
|
||||
obj = SchedulerObject()
|
||||
|
||||
# mock add obj class to `diffusers`
|
||||
setattr(diffusers, "SchedulerObject", SchedulerObject)
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
obj.save_config(tmpdirname)
|
||||
with CaptureLogger(logger) as cap_logger_1:
|
||||
config = SchedulerObject2.load_config(tmpdirname)
|
||||
new_obj_1 = SchedulerObject2.from_config(config)
|
||||
|
||||
# now save a config parameter that is not expected
|
||||
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
|
||||
data = json.load(f)
|
||||
data["unexpected"] = True
|
||||
|
||||
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_2:
|
||||
config = SchedulerObject.load_config(tmpdirname)
|
||||
new_obj_2 = SchedulerObject.from_config(config)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_3:
|
||||
config = SchedulerObject2.load_config(tmpdirname)
|
||||
new_obj_3 = SchedulerObject2.from_config(config)
|
||||
|
||||
assert new_obj_1.__class__ == SchedulerObject2
|
||||
assert new_obj_2.__class__ == SchedulerObject
|
||||
assert new_obj_3.__class__ == SchedulerObject2
|
||||
|
||||
assert cap_logger_1.out == ""
|
||||
assert (
|
||||
cap_logger_2.out
|
||||
== "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
|
||||
" will"
|
||||
" be ignored. Please verify your config.json configuration file.\n"
|
||||
)
|
||||
assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out
|
||||
|
||||
    def test_save_load_compatible_schedulers(self):
        SchedulerObject2._compatibles = ["SchedulerObject"]
        SchedulerObject._compatibles = ["SchedulerObject2"]

        obj = SchedulerObject()

        # mock add obj class to `diffusers`
        setattr(diffusers, "SchedulerObject", SchedulerObject)
        setattr(diffusers, "SchedulerObject2", SchedulerObject2)
        logger = logging.get_logger("diffusers.configuration_utils")

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)

            # now save a config parameter that is expected by another class, but not by the origin class
            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
                data = json.load(f)
                data["f"] = [0, 0]
                data["unexpected"] = True

            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
                json.dump(data, f)

            with CaptureLogger(logger) as cap_logger:
                config = SchedulerObject.load_config(tmpdirname)
                new_obj = SchedulerObject.from_config(config)

        assert new_obj.__class__ == SchedulerObject

        assert (
            cap_logger.out
            == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
            " will be ignored. Please verify your config.json configuration file.\n"
        )
    def test_save_load_from_different_config_comp_schedulers(self):
        SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"]
        SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"]
        SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"]

        obj = SchedulerObject()

        # mock add obj class to `diffusers`
        setattr(diffusers, "SchedulerObject", SchedulerObject)
        setattr(diffusers, "SchedulerObject2", SchedulerObject2)
        setattr(diffusers, "SchedulerObject3", SchedulerObject3)
        logger = logging.get_logger("diffusers.configuration_utils")
        logger.setLevel(diffusers.logging.INFO)

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)

            with CaptureLogger(logger) as cap_logger_1:
                config = SchedulerObject.load_config(tmpdirname)
                new_obj_1 = SchedulerObject.from_config(config)

            with CaptureLogger(logger) as cap_logger_2:
                config = SchedulerObject2.load_config(tmpdirname)
                new_obj_2 = SchedulerObject2.from_config(config)

            with CaptureLogger(logger) as cap_logger_3:
                config = SchedulerObject3.load_config(tmpdirname)
                new_obj_3 = SchedulerObject3.from_config(config)

        assert new_obj_1.__class__ == SchedulerObject
        assert new_obj_2.__class__ == SchedulerObject2
        assert new_obj_3.__class__ == SchedulerObject3

        assert cap_logger_1.out == ""
        assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
        assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
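The `_compatibles` lists patched onto the dummy classes mirror how real schedulers declare which classes can be instantiated from each other's configs. In user code the same mechanism looks roughly like the sketch below (DDPM and DDIM are used purely as an illustration of two schedulers that share training parameters):

```python
# Illustrative sketch: instantiate one scheduler from another's config.
from diffusers import DDIMScheduler, DDPMScheduler

ddpm = DDPMScheduler(num_train_timesteps=1000)
# `from_config` reuses the registered parameters; keys DDIM does not accept
# are ignored with a warning, and missing keys fall back to DDIM's defaults.
ddim = DDIMScheduler.from_config(ddpm.config)
print(ddim.config.num_train_timesteps)  # -> 1000
```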
class SchedulerCommonTest(unittest.TestCase):
    scheduler_classes = ()
    forward_default_kwargs = ()

@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)

        if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
            scheduler.set_timesteps(num_inference_steps)
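The one-line change repeated in this and the following hunks is the API split this PR introduces: `from_pretrained` now loads a scheduler saved to a directory (or hub repo), while `from_config` is reserved for instantiating from an in-memory config. A minimal round-trip sketch:

```python
# Round-trip sketch: save a scheduler, reload it with the new API.
import tempfile

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

with tempfile.TemporaryDirectory() as tmpdirname:
    # The tests here pair save_config with from_pretrained the same way.
    scheduler.save_pretrained(tmpdirname)
    reloaded = DDPMScheduler.from_pretrained(tmpdirname)

assert reloaded.config == scheduler.config
```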
@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)

        if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
            scheduler.set_timesteps(num_inference_steps)

@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)

        if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
            scheduler.set_timesteps(num_inference_steps)
@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):

        assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_compatibles(self):
        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()

            scheduler = scheduler_class(**scheduler_config)

            assert all(c is not None for c in scheduler.compatibles)

            for comp_scheduler_cls in scheduler.compatibles:
                comp_scheduler = comp_scheduler_cls.from_config(scheduler.config)
                assert comp_scheduler is not None

            new_scheduler = scheduler_class.from_config(comp_scheduler.config)

            new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config}
            scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config}

            # make sure that the configs are essentially identical
            assert new_scheduler_config == dict(scheduler.config)

            # make sure that the only differences are config values that are not init arguments
            init_keys = inspect.signature(scheduler_class.__init__).parameters.keys()
            assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set()
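`test_compatibles` exercises the new `compatibles` property, which exposes the classes a scheduler's config can be loaded into. A short sketch of how it might be inspected in user code (the printed list depends on the installed diffusers version):

```python
# Sketch: list scheduler classes declared compatible with DDPMScheduler.
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
print([cls.__name__ for cls in scheduler.compatibles])
```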
    def test_from_pretrained(self):
        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()

            scheduler = scheduler_class(**scheduler_config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_pretrained(tmpdirname)
                new_scheduler = scheduler_class.from_pretrained(tmpdirname)

            assert scheduler.config == new_scheduler.config

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)
@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
            new_scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
            # copy over dummy past residuals
            new_scheduler.set_timesteps(num_inference_steps)

@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
            new_scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            new_scheduler.ets = dummy_past_residuals[:]

@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
            # copy over dummy past residuals
            new_scheduler.set_timesteps(num_inference_steps)

@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)

        output = scheduler.step_pred(
            residual, time_step, sample, generator=torch.manual_seed(0), **kwargs

@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)

        output = scheduler.step_pred(
            residual, time_step, sample, generator=torch.manual_seed(0), **kwargs

@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
            new_scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            new_scheduler.ets = dummy_past_residuals[:]

@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler = scheduler_class.from_config(tmpdirname)
            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
            # copy over dummy past residuals
            new_scheduler.set_timesteps(num_inference_steps)
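The remaining hunks apply the same rename to the Flax tests, with one difference worth noting: Flax schedulers are stateless, so `from_pretrained` returns a `(scheduler, state)` tuple instead of a single object. A hedged sketch (requires `jax` and `flax`; `FlaxDDPMScheduler` stands in for any Flax scheduler):

```python
# Sketch under the assumption that FlaxDDPMScheduler follows the API tested below.
import tempfile

from diffusers import FlaxDDPMScheduler

scheduler = FlaxDDPMScheduler(num_train_timesteps=1000)

with tempfile.TemporaryDirectory() as tmpdirname:
    scheduler.save_config(tmpdirname)
    # Tuple return: the immutable scheduler state travels alongside the class.
    new_scheduler, state = FlaxDDPMScheduler.from_pretrained(tmpdirname)

state = new_scheduler.set_timesteps(state, num_inference_steps=50)
```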
@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

        if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
            state = scheduler.set_timesteps(state, num_inference_steps)

@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

        if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
            state = scheduler.set_timesteps(state, num_inference_steps)

@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

        if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
            state = scheduler.set_timesteps(state, num_inference_steps)

@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

        if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
            state = scheduler.set_timesteps(state, num_inference_steps)

@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

        if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
            state = scheduler.set_timesteps(state, num_inference_steps)

@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)

        if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
            state = scheduler.set_timesteps(state, num_inference_steps)

@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
            new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
            # copy over dummy past residuals
            new_state = new_state.replace(ets=dummy_past_residuals[:])

@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
            # copy over dummy past residuals
            new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)