mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge branch 'main' of https://github.com/huggingface/diffusers into main
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -163,4 +163,6 @@ tags
|
||||
*.lock
|
||||
|
||||
# DS_Store (MacOS)
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
# RL pipelines may produce mp4 outputs
|
||||
*.mp4
|
||||
10
README.md
10
README.md
@@ -152,15 +152,7 @@ it before the pipeline and pass it to `from_pretrained`.
|
||||
```python
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
|
||||
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
scheduler=lms,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
image = pipe(prompt).images[0]
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
- sections:
|
||||
- local: using-diffusers/loading
|
||||
title: "Loading Pipelines, Models, and Schedulers"
|
||||
- local: using-diffusers/schedulers
|
||||
title: "Using different Schedulers"
|
||||
- local: using-diffusers/configuration
|
||||
title: "Configuring Pipelines, Models, and Schedulers"
|
||||
- local: using-diffusers/custom_pipeline_overview
|
||||
|
||||
@@ -15,9 +15,9 @@ specific language governing permissions and limitations under the License.
|
||||
In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are
|
||||
passed to the respective `__init__` methods in a JSON-configuration file.
|
||||
|
||||
TODO(PVP) - add example and better info here
|
||||
|
||||
## ConfigMixin
|
||||
|
||||
[[autodoc]] ConfigMixin
|
||||
- load_config
|
||||
- from_config
|
||||
- save_config
|
||||
|
||||
@@ -22,12 +22,15 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
|
||||
## UNet2DOutput
|
||||
[[autodoc]] models.unet_2d.UNet2DOutput
|
||||
|
||||
## UNet1DModel
|
||||
[[autodoc]] UNet1DModel
|
||||
|
||||
## UNet2DModel
|
||||
[[autodoc]] UNet2DModel
|
||||
|
||||
## UNet1DOutput
|
||||
[[autodoc]] models.unet_1d.UNet1DOutput
|
||||
|
||||
## UNet1DModel
|
||||
[[autodoc]] UNet1DModel
|
||||
|
||||
## UNet2DConditionOutput
|
||||
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
|
||||
# load the pipeline
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
|
||||
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
|
||||
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
|
||||
|
||||
# let's download an initial image
|
||||
|
||||
@@ -54,7 +54,7 @@ original_image = download_image(img_url).resize((256, 256))
|
||||
mask_image = download_image(mask_url).resize((256, 256))
|
||||
|
||||
# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
|
||||
scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256")
|
||||
scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
|
||||
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
|
||||
@@ -34,13 +34,17 @@ For more details about how Stable Diffusion works and how it differs from the ba
|
||||
### How to load and use different schedulers.
|
||||
|
||||
The stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
|
||||
To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
|
||||
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
|
||||
>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
|
||||
|
||||
euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
|
||||
>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
>>> # or
|
||||
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generat
|
||||
```python
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
|
||||
>>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
```
|
||||
|
||||
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
|
||||
@@ -49,13 +49,13 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm
|
||||
You can move the generator object to GPU, just like you would in PyTorch.
|
||||
|
||||
```python
|
||||
>>> generator.to("cuda")
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
|
||||
Now you can use the `generator` on your text prompt:
|
||||
Now you can use the `pipeline` on your text prompt:
|
||||
|
||||
```python
|
||||
>>> image = generator("An image of a squirrel in Picasso style").images[0]
|
||||
>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
|
||||
```
|
||||
|
||||
The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).
|
||||
@@ -82,7 +82,7 @@ just like we did before only that now you need to pass your `AUTH_TOKEN`:
|
||||
```python
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
|
||||
>>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
|
||||
```
|
||||
|
||||
If you do not pass your authentication token you will see that the diffusion system will not be correctly
|
||||
@@ -102,7 +102,7 @@ token. Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned
|
||||
you can also load the pipeline as follows:
|
||||
|
||||
```python
|
||||
>>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
|
||||
```
|
||||
|
||||
Running the pipeline is then identical to the code above as it's the same model architecture.
|
||||
@@ -115,19 +115,20 @@ Running the pipeline is then identical to the code above as it's the same model
|
||||
|
||||
Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their
|
||||
pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
|
||||
use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler,
|
||||
use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`] scheduler,
|
||||
you could use it as follows:
|
||||
|
||||
```python
|
||||
>>> from diffusers import LMSDiscreteScheduler
|
||||
>>> from diffusers import EulerDiscreteScheduler
|
||||
|
||||
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
|
||||
>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
|
||||
|
||||
>>> generator = StableDiffusionPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
|
||||
... )
|
||||
>>> # change scheduler to Euler
|
||||
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
```
|
||||
|
||||
For more in-detail information on how to change between schedulers, please refer to the [Using Schedulers](./using-diffusers/schedulers) guide.
|
||||
|
||||
[Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model
|
||||
and can do much more than just generating images from text. We have dedicated a whole documentation page,
|
||||
just for Stable Diffusion [here](./conceptual/stable_diffusion).
|
||||
|
||||
@@ -19,7 +19,7 @@ In the following we explain in-detail how to easily load:
|
||||
|
||||
- *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`]
|
||||
- *Diffusion Models* via [`ModelMixin.from_pretrained`]
|
||||
- *Schedulers* via [`ConfigMixin.from_config`]
|
||||
- *Schedulers* via [`SchedulerMixin.from_pretrained`]
|
||||
|
||||
## Loading pipelines
|
||||
|
||||
@@ -137,15 +137,15 @@ from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultis
|
||||
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
scheduler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
|
||||
# or
|
||||
# scheduler = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
# scheduler = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
|
||||
|
||||
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)
|
||||
```
|
||||
|
||||
Three things are worth paying attention to here.
|
||||
- First, the scheduler is loaded with [`ConfigMixin.from_config`] since it only depends on a configuration file and not any parameterized weights
|
||||
- First, the scheduler is loaded with [`SchedulerMixin.from_pretrained`]
|
||||
- Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"` as the configuration of stable diffusion's scheduling is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler)
|
||||
- Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler` since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`]
|
||||
|
||||
@@ -337,8 +337,8 @@ model = UNet2DModel.from_pretrained(repo_id)
|
||||
|
||||
## Loading schedulers
|
||||
|
||||
Schedulers cannot be loaded via a `from_pretrained` method, but instead rely on [`ConfigMixin.from_config`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
|
||||
Therefore the loading method was given a different name here.
|
||||
Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
|
||||
For consistency, we use the same method name as we do for models or pipelines, but no weights are loaded in this case.
|
||||
|
||||
In constrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers.
|
||||
For example, all of:
|
||||
@@ -367,13 +367,13 @@ from diffusers import (
|
||||
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
ddpm = DDPMScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
ddim = DDIMScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
pndm = PNDMScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
lms = LMSDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
euler_anc = EulerAncestralDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
euler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
|
||||
ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
|
||||
pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
|
||||
lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
|
||||
euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
|
||||
euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
|
||||
dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
|
||||
|
||||
# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc`
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
|
||||
|
||||
262
docs/source/using-diffusers/schedulers.mdx
Normal file
262
docs/source/using-diffusers/schedulers.mdx
Normal file
@@ -0,0 +1,262 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Schedulers
|
||||
|
||||
Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize
|
||||
a pipeline to one's use case. The best example of this are the [Schedulers](../api/schedulers.mdx).
|
||||
|
||||
Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample,
|
||||
schedulers define the whole denoising process, *i.e.*:
|
||||
- How many denoising steps?
|
||||
- Stochastic or deterministic?
|
||||
- What algorithm to use to find the denoised sample
|
||||
|
||||
They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**.
|
||||
It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try out which works best.
|
||||
|
||||
The following paragraphs shows how to do so with the 🧨 Diffusers library.
|
||||
|
||||
## Load pipeline
|
||||
|
||||
Let's start by loading the stable diffusion pipeline.
|
||||
Remember that you have to be a registered user on the 🤗 Hugging Face Hub, and have "click-accepted" the [license](https://huggingface.co/runwayml/stable-diffusion-v1-5) in order to use stable diffusion.
|
||||
|
||||
```python
|
||||
from huggingface_hub import login
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
# first we need to login with our access token
|
||||
login()
|
||||
|
||||
# Now we can download the pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
Next, we move it to GPU:
|
||||
|
||||
```python
|
||||
pipeline.to("cuda")
|
||||
```
|
||||
|
||||
## Access the scheduler
|
||||
|
||||
The scheduler is always one of the components of the pipeline and is usually called `"scheduler"`.
|
||||
So it can be accessed via the `"scheduler"` property.
|
||||
|
||||
```python
|
||||
pipeline.scheduler
|
||||
```
|
||||
|
||||
**Output**:
|
||||
```
|
||||
PNDMScheduler {
|
||||
"_class_name": "PNDMScheduler",
|
||||
"_diffusers_version": "0.8.0.dev0",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": false,
|
||||
"num_train_timesteps": 1000,
|
||||
"set_alpha_to_one": false,
|
||||
"skip_prk_steps": true,
|
||||
"steps_offset": 1,
|
||||
"trained_betas": null
|
||||
}
|
||||
```
|
||||
|
||||
We can see that the scheduler is of type [`PNDMScheduler`].
|
||||
Cool, now let's compare the scheduler in its performance to other schedulers.
|
||||
First we define a prompt on which we will test all the different schedulers:
|
||||
|
||||
```python
|
||||
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
|
||||
```
|
||||
|
||||
Next, we create a generator from a random seed that will ensure that we can generate similar images as well as run the pipeline:
|
||||
|
||||
```python
|
||||
generator = torch.Generator(device="cuda").manual_seed(8)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_pndm.png" width="400"/>
|
||||
<br>
|
||||
</p>
|
||||
|
||||
|
||||
## Changing the scheduler
|
||||
|
||||
Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`SchedulerMixin.compatibles`]
|
||||
which defines all compatible schedulers. You can take a look at all available, compatible schedulers for the Stable Diffusion pipeline as follows.
|
||||
|
||||
```python
|
||||
pipeline.scheduler.compatibles
|
||||
```
|
||||
|
||||
**Output**:
|
||||
```
|
||||
[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
|
||||
diffusers.schedulers.scheduling_ddim.DDIMScheduler,
|
||||
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
|
||||
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
|
||||
diffusers.schedulers.scheduling_pndm.PNDMScheduler,
|
||||
diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
|
||||
diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]
|
||||
```
|
||||
|
||||
Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions:
|
||||
|
||||
- [`LMSDiscreteScheduler`],
|
||||
- [`DDIMScheduler`],
|
||||
- [`DPMSolverMultistepScheduler`],
|
||||
- [`EulerDiscreteScheduler`],
|
||||
- [`PNDMScheduler`],
|
||||
- [`DDPMScheduler`],
|
||||
- [`EulerAncestralDiscreteScheduler`].
|
||||
|
||||
We will now compare the input prompt with all other schedulers. To change the scheduler of the pipeline you can make use of the
|
||||
convenient [`ConfigMixin.config`] property in combination with the [`ConfigMixin.from_config`] function.
|
||||
|
||||
```python
|
||||
pipeline.scheduler.config
|
||||
```
|
||||
|
||||
returns a dictionary of the configuration of the scheduler:
|
||||
|
||||
**Output**:
|
||||
```
|
||||
FrozenDict([('num_train_timesteps', 1000),
|
||||
('beta_start', 0.00085),
|
||||
('beta_end', 0.012),
|
||||
('beta_schedule', 'scaled_linear'),
|
||||
('trained_betas', None),
|
||||
('skip_prk_steps', True),
|
||||
('set_alpha_to_one', False),
|
||||
('steps_offset', 1),
|
||||
('_class_name', 'PNDMScheduler'),
|
||||
('_diffusers_version', '0.8.0.dev0'),
|
||||
('clip_sample', False)])
|
||||
```
|
||||
|
||||
This configuration can then be used to instantiate a scheduler
|
||||
of a different class that is compatible with the pipeline. Here,
|
||||
we change the scheduler to the [`DDIMScheduler`].
|
||||
|
||||
```python
|
||||
from diffusers import DDIMScheduler
|
||||
|
||||
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
|
||||
```
|
||||
|
||||
Cool, now we can run the pipeline again to compare the generation quality.
|
||||
|
||||
```python
|
||||
generator = torch.Generator(device="cuda").manual_seed(8)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_ddim.png" width="400"/>
|
||||
<br>
|
||||
</p>
|
||||
|
||||
|
||||
## Compare schedulers
|
||||
|
||||
So far we have tried running the stable diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`].
|
||||
A number of better schedulers have been released that can be run with much fewer steps, let's compare them here:
|
||||
|
||||
[`LMSDiscreteScheduler`] usually leads to better results:
|
||||
|
||||
```python
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(8)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png" width="400"/>
|
||||
<br>
|
||||
</p>
|
||||
|
||||
|
||||
[`EulerDiscreteScheduler`] and [`EulerAncestralDiscreteScheduler`] can generate high quality results with as little as 30 steps.
|
||||
|
||||
```python
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(8)
|
||||
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png" width="400"/>
|
||||
<br>
|
||||
</p>
|
||||
|
||||
|
||||
and:
|
||||
|
||||
```python
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
|
||||
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(8)
|
||||
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png" width="400"/>
|
||||
<br>
|
||||
</p>
|
||||
|
||||
|
||||
At the time of writing this doc [`DPMSolverMultistepScheduler`] gives arguably the best speed/quality trade-off and can be run with as little
|
||||
as 20 steps.
|
||||
|
||||
```python
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(8)
|
||||
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png" width="400"/>
|
||||
<br>
|
||||
</p>
|
||||
|
||||
As you can see most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different
|
||||
schedulers to compare results.
|
||||
@@ -42,7 +42,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
|
||||
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
|
||||
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
||||
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
|
||||
|
||||
| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon.
|
||||
|
||||
## Community
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) |
|
||||
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech)
|
||||
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) |
|
||||
| Composable Stable Diffusion| Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Multilingual Stable Diffusion| Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
|
||||
@@ -345,6 +345,8 @@ out = pipe(
|
||||
|
||||
### Composable Stable diffusion
|
||||
|
||||
[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models.
|
||||
|
||||
```python
|
||||
import torch as th
|
||||
import numpy as np
|
||||
|
||||
@@ -92,7 +92,7 @@ accelerate launch train_dreambooth.py \
|
||||
|
||||
With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run train dreambooth on a 16GB GPU.
|
||||
|
||||
Install `bitsandbytes` with `pip install bitsandbytes`
|
||||
To install `bitandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
|
||||
19
examples/rl/README.md
Normal file
19
examples/rl/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Overview
|
||||
|
||||
These examples show how to run (Diffuser)[https://arxiv.org/abs/2205.09991] in Diffusers.
|
||||
There are four scripts,
|
||||
1. `run_diffuser_locomotion.py` to sample actions and run them in the environment,
|
||||
2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model.
|
||||
|
||||
You will need some RL specific requirements to run the examples:
|
||||
|
||||
```
|
||||
pip install -f https://download.pytorch.org/whl/torch_stable.html \
|
||||
free-mujoco-py \
|
||||
einops \
|
||||
gym==0.24.1 \
|
||||
protobuf==3.20.1 \
|
||||
git+https://github.com/rail-berkeley/d4rl.git \
|
||||
mediapy \
|
||||
Pillow==9.0.0
|
||||
```
|
||||
57
examples/rl/run_diffuser_gen_trajectories.py
Normal file
57
examples/rl/run_diffuser_gen_trajectories.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import d4rl # noqa
|
||||
import gym
|
||||
import tqdm
|
||||
from diffusers.experimental import ValueGuidedRLPipeline
|
||||
|
||||
|
||||
config = dict(
|
||||
n_samples=64,
|
||||
horizon=32,
|
||||
num_inference_steps=20,
|
||||
n_guide_steps=0,
|
||||
scale_grad_by_std=True,
|
||||
scale=0.1,
|
||||
eta=0.0,
|
||||
t_grad_cutoff=2,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env_name = "hopper-medium-v2"
|
||||
env = gym.make(env_name)
|
||||
|
||||
pipeline = ValueGuidedRLPipeline.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32",
|
||||
env=env,
|
||||
)
|
||||
|
||||
env.seed(0)
|
||||
obs = env.reset()
|
||||
total_reward = 0
|
||||
total_score = 0
|
||||
T = 1000
|
||||
rollout = [obs.copy()]
|
||||
try:
|
||||
for t in tqdm.tqdm(range(T)):
|
||||
# Call the policy
|
||||
denorm_actions = pipeline(obs, planning_horizon=32)
|
||||
|
||||
# execute action in environment
|
||||
next_observation, reward, terminal, _ = env.step(denorm_actions)
|
||||
score = env.get_normalized_score(total_reward)
|
||||
# update return
|
||||
total_reward += reward
|
||||
total_score += score
|
||||
print(
|
||||
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
|
||||
f" {total_score}"
|
||||
)
|
||||
# save observations for rendering
|
||||
rollout.append(next_observation.copy())
|
||||
|
||||
obs = next_observation
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
print(f"Total reward: {total_reward}")
|
||||
57
examples/rl/run_diffuser_locomotion.py
Normal file
57
examples/rl/run_diffuser_locomotion.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import d4rl # noqa
|
||||
import gym
|
||||
import tqdm
|
||||
from diffusers.experimental import ValueGuidedRLPipeline
|
||||
|
||||
|
||||
config = dict(
|
||||
n_samples=64,
|
||||
horizon=32,
|
||||
num_inference_steps=20,
|
||||
n_guide_steps=2,
|
||||
scale_grad_by_std=True,
|
||||
scale=0.1,
|
||||
eta=0.0,
|
||||
t_grad_cutoff=2,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env_name = "hopper-medium-v2"
|
||||
env = gym.make(env_name)
|
||||
|
||||
pipeline = ValueGuidedRLPipeline.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32",
|
||||
env=env,
|
||||
)
|
||||
|
||||
env.seed(0)
|
||||
obs = env.reset()
|
||||
total_reward = 0
|
||||
total_score = 0
|
||||
T = 1000
|
||||
rollout = [obs.copy()]
|
||||
try:
|
||||
for t in tqdm.tqdm(range(T)):
|
||||
# call the policy
|
||||
denorm_actions = pipeline(obs, planning_horizon=32)
|
||||
|
||||
# execute action in environment
|
||||
next_observation, reward, terminal, _ = env.step(denorm_actions)
|
||||
score = env.get_normalized_score(total_reward)
|
||||
# update return
|
||||
total_reward += reward
|
||||
total_score += score
|
||||
print(
|
||||
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
|
||||
f" {total_score}"
|
||||
)
|
||||
# save observations for rendering
|
||||
rollout.append(next_observation.copy())
|
||||
|
||||
obs = next_observation
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
print(f"Total reward: {total_reward}")
|
||||
100
scripts/convert_models_diffuser_to_diffusers.py
Normal file
100
scripts/convert_models_diffuser_to_diffusers.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import UNet1DModel
|
||||
|
||||
|
||||
os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
|
||||
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)
|
||||
|
||||
os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)
|
||||
|
||||
|
||||
def unet(hor):
|
||||
if hor == 128:
|
||||
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
|
||||
block_out_channels = (32, 128, 256)
|
||||
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
|
||||
|
||||
elif hor == 32:
|
||||
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
|
||||
block_out_channels = (32, 64, 128, 256)
|
||||
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
|
||||
model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
|
||||
state_dict = model.state_dict()
|
||||
config = dict(
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
up_block_types=up_block_types,
|
||||
layers_per_block=1,
|
||||
use_timestep_embedding=True,
|
||||
out_block_type="OutConv1DBlock",
|
||||
norm_num_groups=8,
|
||||
downsample_each_block=False,
|
||||
in_channels=14,
|
||||
out_channels=14,
|
||||
extra_in_channels=0,
|
||||
time_embedding_type="positional",
|
||||
flip_sin_to_cos=False,
|
||||
freq_shift=1,
|
||||
sample_size=65536,
|
||||
mid_block_type="MidResTemporalBlock1D",
|
||||
act_fn="mish",
|
||||
)
|
||||
hf_value_function = UNet1DModel(**config)
|
||||
print(f"length of state dict: {len(state_dict.keys())}")
|
||||
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
|
||||
mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
|
||||
for k, v in mapping.items():
|
||||
state_dict[v] = state_dict.pop(k)
|
||||
hf_value_function.load_state_dict(state_dict)
|
||||
|
||||
torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
|
||||
with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
|
||||
def value_function():
|
||||
config = dict(
|
||||
in_channels=14,
|
||||
down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
|
||||
up_block_types=(),
|
||||
out_block_type="ValueFunction",
|
||||
mid_block_type="ValueFunctionMidBlock1D",
|
||||
block_out_channels=(32, 64, 128, 256),
|
||||
layers_per_block=1,
|
||||
downsample_each_block=True,
|
||||
sample_size=65536,
|
||||
out_channels=14,
|
||||
extra_in_channels=0,
|
||||
time_embedding_type="positional",
|
||||
use_timestep_embedding=True,
|
||||
flip_sin_to_cos=False,
|
||||
freq_shift=1,
|
||||
norm_num_groups=8,
|
||||
act_fn="mish",
|
||||
)
|
||||
|
||||
model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
|
||||
state_dict = model
|
||||
hf_value_function = UNet1DModel(**config)
|
||||
print(f"length of state dict: {len(state_dict.keys())}")
|
||||
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
|
||||
|
||||
mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys()))
|
||||
for k, v in mapping.items():
|
||||
state_dict[v] = state_dict.pop(k)
|
||||
|
||||
hf_value_function.load_state_dict(state_dict)
|
||||
|
||||
torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
|
||||
with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unet(32)
|
||||
# unet(128)
|
||||
value_function()
|
||||
@@ -29,7 +29,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
|
||||
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -37,6 +37,38 @@ logger = logging.get_logger(__name__)
|
||||
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
||||
|
||||
|
||||
class FrozenDict(OrderedDict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
for key, value in self.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
self.__frozen = True
|
||||
|
||||
def __delitem__(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def setdefault(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def pop(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if hasattr(self, "__frozen") and self.__frozen:
|
||||
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
if hasattr(self, "__frozen") and self.__frozen:
|
||||
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
||||
super().__setitem__(name, value)
|
||||
|
||||
|
||||
class ConfigMixin:
|
||||
r"""
|
||||
Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
|
||||
@@ -49,13 +81,12 @@ class ConfigMixin:
|
||||
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
||||
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
||||
overridden by parent class).
|
||||
- **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
|
||||
`from_config` can be used from a class different than the one used to save the config (should be overridden
|
||||
by parent class).
|
||||
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
|
||||
class).
|
||||
"""
|
||||
config_name = None
|
||||
ignore_for_config = []
|
||||
_compatible_classes = []
|
||||
has_compatibles = False
|
||||
|
||||
def register_to_config(self, **kwargs):
|
||||
if self.config_name is None:
|
||||
@@ -104,9 +135,98 @@ class ConfigMixin:
|
||||
logger.info(f"Configuration saved in {output_config_file}")
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
|
||||
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
||||
r"""
|
||||
Instantiate a Python class from a pre-defined JSON-file.
|
||||
Instantiate a Python class from a config dictionary
|
||||
|
||||
Parameters:
|
||||
config (`Dict[str, Any]`):
|
||||
A config dictionary from which the Python class will be instantiated. Make sure to only load
|
||||
configuration files of compatible classes.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
Whether kwargs that are not consumed by the Python class should be returned or not.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the Python class.
|
||||
`**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
|
||||
overwrite same named arguments of `config`.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
|
||||
|
||||
>>> # Download scheduler from huggingface.co and cache.
|
||||
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
|
||||
|
||||
>>> # Instantiate DDIM scheduler class with same config as DDPM
|
||||
>>> scheduler = DDIMScheduler.from_config(scheduler.config)
|
||||
|
||||
>>> # Instantiate PNDM scheduler class with same config as DDPM
|
||||
>>> scheduler = PNDMScheduler.from_config(scheduler.config)
|
||||
```
|
||||
"""
|
||||
# <===== TO BE REMOVED WITH DEPRECATION
|
||||
# TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
|
||||
if "pretrained_model_name_or_path" in kwargs:
|
||||
config = kwargs.pop("pretrained_model_name_or_path")
|
||||
|
||||
if config is None:
|
||||
raise ValueError("Please make sure to provide a config as the first positional argument.")
|
||||
# ======>
|
||||
|
||||
if not isinstance(config, dict):
|
||||
deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
|
||||
if "Scheduler" in cls.__name__:
|
||||
deprecation_message += (
|
||||
f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
|
||||
" Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
|
||||
" be removed in v1.0.0."
|
||||
)
|
||||
elif "Model" in cls.__name__:
|
||||
deprecation_message += (
|
||||
f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
|
||||
f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
|
||||
" instead. This functionality will be removed in v1.0.0."
|
||||
)
|
||||
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
|
||||
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
|
||||
|
||||
init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
|
||||
|
||||
# Allow dtype to be specified on initialization
|
||||
if "dtype" in unused_kwargs:
|
||||
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
||||
|
||||
# Return model and optionally state and/or unused_kwargs
|
||||
model = cls(**init_dict)
|
||||
|
||||
# make sure to also save config parameters that might be used for compatible classes
|
||||
model.register_to_config(**hidden_dict)
|
||||
|
||||
# add hidden kwargs of compatible classes to unused_kwargs
|
||||
unused_kwargs = {**unused_kwargs, **hidden_dict}
|
||||
|
||||
if return_unused_kwargs:
|
||||
return (model, unused_kwargs)
|
||||
else:
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(cls, *args, **kwargs):
|
||||
deprecation_message = (
|
||||
f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
|
||||
" removed in version v1.0.0"
|
||||
)
|
||||
deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
|
||||
return cls.load_config(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def load_config(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
r"""
|
||||
Instantiate a Python class from a config dictionary
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
@@ -120,10 +240,6 @@ class ConfigMixin:
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
|
||||
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
|
||||
checkpoint with 3 labels).
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
@@ -161,33 +277,7 @@ class ConfigMixin:
|
||||
use this method in a firewalled environment.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
|
||||
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
# Allow dtype to be specified on initialization
|
||||
if "dtype" in unused_kwargs:
|
||||
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
||||
|
||||
# Return model and optionally state and/or unused_kwargs
|
||||
model = cls(**init_dict)
|
||||
return_tuple = (model,)
|
||||
|
||||
# Flax schedulers have a state, so return it.
|
||||
if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
|
||||
state = model.create_state()
|
||||
return_tuple += (state,)
|
||||
|
||||
if return_unused_kwargs:
|
||||
return return_tuple + (unused_kwargs,)
|
||||
else:
|
||||
return return_tuple if len(return_tuple) > 1 else model
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
@@ -283,6 +373,9 @@ class ConfigMixin:
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
||||
|
||||
if return_unused_kwargs:
|
||||
return config_dict, kwargs
|
||||
|
||||
return config_dict
|
||||
|
||||
@staticmethod
|
||||
@@ -291,6 +384,9 @@ class ConfigMixin:
|
||||
|
||||
@classmethod
|
||||
def extract_init_dict(cls, config_dict, **kwargs):
|
||||
# 0. Copy origin config dict
|
||||
original_dict = {k: v for k, v in config_dict.items()}
|
||||
|
||||
# 1. Retrieve expected config attributes from __init__ signature
|
||||
expected_keys = cls._get_init_keys(cls)
|
||||
expected_keys.remove("self")
|
||||
@@ -310,10 +406,11 @@ class ConfigMixin:
|
||||
# load diffusers library to import compatible and original scheduler
|
||||
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
||||
|
||||
# remove attributes from compatible classes that orig cannot expect
|
||||
compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
|
||||
# filter out None potentially undefined dummy classes
|
||||
compatible_classes = [c for c in compatible_classes if c is not None]
|
||||
if cls.has_compatibles:
|
||||
compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
|
||||
else:
|
||||
compatible_classes = []
|
||||
|
||||
expected_keys_comp_cls = set()
|
||||
for c in compatible_classes:
|
||||
expected_keys_c = cls._get_init_keys(c)
|
||||
@@ -364,7 +461,10 @@ class ConfigMixin:
|
||||
# 6. Define unused keyword arguments
|
||||
unused_kwargs = {**config_dict, **kwargs}
|
||||
|
||||
return init_dict, unused_kwargs
|
||||
# 7. Define "hidden" config parameters that were saved for compatible classes
|
||||
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")}
|
||||
|
||||
return init_dict, unused_kwargs, hidden_config_dict
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||
@@ -377,6 +477,12 @@ class ConfigMixin:
|
||||
|
||||
@property
|
||||
def config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns the config of the class as a frozen dictionary
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Config of the class.
|
||||
"""
|
||||
return self._internal_dict
|
||||
|
||||
def to_json_string(self) -> str:
|
||||
@@ -401,38 +507,6 @@ class ConfigMixin:
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
|
||||
class FrozenDict(OrderedDict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
for key, value in self.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
self.__frozen = True
|
||||
|
||||
def __delitem__(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def setdefault(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def pop(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if hasattr(self, "__frozen") and self.__frozen:
|
||||
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
if hasattr(self, "__frozen") and self.__frozen:
|
||||
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
||||
super().__setitem__(name, value)
|
||||
|
||||
|
||||
def register_to_config(init):
|
||||
r"""
|
||||
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
|
||||
|
||||
5
src/diffusers/experimental/README.md
Normal file
5
src/diffusers/experimental/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# 🧨 Diffusers Experimental
|
||||
|
||||
We are adding experimental code to support novel applications and usages of the Diffusers library.
|
||||
Currently, the following experiments are supported:
|
||||
* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
|
||||
1
src/diffusers/experimental/__init__.py
Normal file
1
src/diffusers/experimental/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .rl import ValueGuidedRLPipeline
|
||||
1
src/diffusers/experimental/rl/__init__.py
Normal file
1
src/diffusers/experimental/rl/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .value_guided_sampling import ValueGuidedRLPipeline
|
||||
129
src/diffusers/experimental/rl/value_guided_sampling.py
Normal file
129
src/diffusers/experimental/rl/value_guided_sampling.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import tqdm
|
||||
|
||||
from ...models.unet_1d import UNet1DModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...utils.dummy_pt_objects import DDPMScheduler
|
||||
|
||||
|
||||
class ValueGuidedRLPipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
value_function: UNet1DModel,
|
||||
unet: UNet1DModel,
|
||||
scheduler: DDPMScheduler,
|
||||
env,
|
||||
):
|
||||
super().__init__()
|
||||
self.value_function = value_function
|
||||
self.unet = unet
|
||||
self.scheduler = scheduler
|
||||
self.env = env
|
||||
self.data = env.get_dataset()
|
||||
self.means = dict()
|
||||
for key in self.data.keys():
|
||||
try:
|
||||
self.means[key] = self.data[key].mean()
|
||||
except:
|
||||
pass
|
||||
self.stds = dict()
|
||||
for key in self.data.keys():
|
||||
try:
|
||||
self.stds[key] = self.data[key].std()
|
||||
except:
|
||||
pass
|
||||
self.state_dim = env.observation_space.shape[0]
|
||||
self.action_dim = env.action_space.shape[0]
|
||||
|
||||
def normalize(self, x_in, key):
|
||||
return (x_in - self.means[key]) / self.stds[key]
|
||||
|
||||
def de_normalize(self, x_in, key):
|
||||
return x_in * self.stds[key] + self.means[key]
|
||||
|
||||
def to_torch(self, x_in):
|
||||
if type(x_in) is dict:
|
||||
return {k: self.to_torch(v) for k, v in x_in.items()}
|
||||
elif torch.is_tensor(x_in):
|
||||
return x_in.to(self.unet.device)
|
||||
return torch.tensor(x_in, device=self.unet.device)
|
||||
|
||||
def reset_x0(self, x_in, cond, act_dim):
|
||||
for key, val in cond.items():
|
||||
x_in[:, key, act_dim:] = val.clone()
|
||||
return x_in
|
||||
|
||||
def run_diffusion(self, x, conditions, n_guide_steps, scale):
|
||||
batch_size = x.shape[0]
|
||||
y = None
|
||||
for i in tqdm.tqdm(self.scheduler.timesteps):
|
||||
# create batch of timesteps to pass into model
|
||||
timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
|
||||
for _ in range(n_guide_steps):
|
||||
with torch.enable_grad():
|
||||
x.requires_grad_()
|
||||
y = self.value_function(x.permute(0, 2, 1), timesteps).sample
|
||||
grad = torch.autograd.grad([y.sum()], [x])[0]
|
||||
|
||||
posterior_variance = self.scheduler._get_variance(i)
|
||||
model_std = torch.exp(0.5 * posterior_variance)
|
||||
grad = model_std * grad
|
||||
grad[timesteps < 2] = 0
|
||||
x = x.detach()
|
||||
x = x + scale * grad
|
||||
x = self.reset_x0(x, conditions, self.action_dim)
|
||||
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
|
||||
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
|
||||
|
||||
# apply conditions to the trajectory
|
||||
x = self.reset_x0(x, conditions, self.action_dim)
|
||||
x = self.to_torch(x)
|
||||
return x, y
|
||||
|
||||
def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
|
||||
# normalize the observations and create batch dimension
|
||||
obs = self.normalize(obs, "observations")
|
||||
obs = obs[None].repeat(batch_size, axis=0)
|
||||
|
||||
conditions = {0: self.to_torch(obs)}
|
||||
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
|
||||
|
||||
# generate initial noise and apply our conditions (to make the trajectories start at current state)
|
||||
x1 = torch.randn(shape, device=self.unet.device)
|
||||
x = self.reset_x0(x1, conditions, self.action_dim)
|
||||
x = self.to_torch(x)
|
||||
|
||||
# run the diffusion process
|
||||
x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
|
||||
|
||||
# sort output trajectories by value
|
||||
sorted_idx = y.argsort(0, descending=True).squeeze()
|
||||
sorted_values = x[sorted_idx]
|
||||
actions = sorted_values[:, :, : self.action_dim]
|
||||
actions = actions.detach().cpu().numpy()
|
||||
denorm_actions = self.de_normalize(actions, key="actions")
|
||||
|
||||
# select the action with the highest value
|
||||
if y is not None:
|
||||
selected_index = 0
|
||||
else:
|
||||
# if we didn't run value guiding, select a random action
|
||||
selected_index = np.random.randint(0, batch_size)
|
||||
denorm_actions = denorm_actions[selected_index, 0]
|
||||
return denorm_actions
|
||||
@@ -62,14 +62,21 @@ def get_timestep_embedding(
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(channel, time_embed_dim)
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.act = None
|
||||
if act_fn == "silu":
|
||||
self.act = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
elif act_fn == "mish":
|
||||
self.act = nn.Mish()
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
@@ -5,6 +5,75 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels: channels in the inputs and outputs.
|
||||
use_conv: a bool determining if a convolution is applied.
|
||||
use_conv_transpose:
|
||||
out_channels:
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
self.conv = None
|
||||
if use_conv_transpose:
|
||||
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(x)
|
||||
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels: channels in the inputs and outputs.
|
||||
use_conv: a bool determining if a convolution is applied.
|
||||
out_channels:
|
||||
padding:
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
@@ -12,7 +81,8 @@ class Upsample2D(nn.Module):
|
||||
Parameters:
|
||||
channels: channels in the inputs and outputs.
|
||||
use_conv: a bool determining if a convolution is applied.
|
||||
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
|
||||
use_conv_transpose:
|
||||
out_channels:
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
||||
@@ -80,7 +150,8 @@ class Downsample2D(nn.Module):
|
||||
Parameters:
|
||||
channels: channels in the inputs and outputs.
|
||||
use_conv: a bool determining if a convolution is applied.
|
||||
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
|
||||
out_channels:
|
||||
padding:
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
||||
@@ -415,6 +486,69 @@ class Mish(torch.nn.Module):
|
||||
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
def rearrange_dims(tensor):
|
||||
if len(tensor.shape) == 2:
|
||||
return tensor[:, :, None]
|
||||
if len(tensor.shape) == 3:
|
||||
return tensor[:, :, None, :]
|
||||
elif len(tensor.shape) == 4:
|
||||
return tensor[:, :, 0, :]
|
||||
else:
|
||||
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.group_norm = nn.GroupNorm(n_groups, out_channels)
|
||||
self.mish = nn.Mish()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1d(x)
|
||||
x = rearrange_dims(x)
|
||||
x = self.group_norm(x)
|
||||
x = rearrange_dims(x)
|
||||
x = self.mish(x)
|
||||
return x
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
class ResidualTemporalBlock1D(nn.Module):
|
||||
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
|
||||
super().__init__()
|
||||
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
|
||||
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
|
||||
|
||||
self.time_emb_act = nn.Mish()
|
||||
self.time_emb = nn.Linear(embed_dim, out_channels)
|
||||
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, t):
|
||||
"""
|
||||
Args:
|
||||
x : [ batch_size x inp_channels x horizon ]
|
||||
t : [ batch_size x embed_dim ]
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
"""
|
||||
t = self.time_emb_act(t)
|
||||
t = self.time_emb(t)
|
||||
out = self.conv_in(x) + rearrange_dims(t)
|
||||
out = self.conv_out(out)
|
||||
return out + self.residual_conv(x)
|
||||
|
||||
|
||||
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
||||
r"""Upsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
@@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
implements for all the model (such as downloading or saving, etc.)
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime.
|
||||
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
|
||||
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
|
||||
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
|
||||
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to :
|
||||
obj:`False`): Whether to flip sin to cos for fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
@@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to :
|
||||
obj:`(32, 32, 64)`): Tuple of block output channels.
|
||||
mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
|
||||
out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
|
||||
act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks.
|
||||
norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks.
|
||||
layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block.
|
||||
downsample_each_block (`int`, *optional*, defaults to False:
|
||||
experimental feature for using a UNet without upsampling.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -54,16 +75,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
mid_block_type: str = "UNetMidBlock1D",
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
layers_per_block: int = 1,
|
||||
downsample_each_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
# time
|
||||
@@ -73,12 +98,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=timestep_input_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
out_dim=block_out_channels[0],
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
@@ -94,38 +126,66 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
if i == 0:
|
||||
input_channel += extra_in_channels
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_downsample=not is_final_block or downsample_each_block,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = get_mid_block(
|
||||
mid_block_type=mid_block_type,
|
||||
mid_channels=block_out_channels[-1],
|
||||
mid_block_type,
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=None,
|
||||
mid_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
embed_dim=block_out_channels[0],
|
||||
num_layers=layers_per_block,
|
||||
add_downsample=downsample_each_block,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
if out_block_type is None:
|
||||
final_upsample_channels = out_channels
|
||||
else:
|
||||
final_upsample_channels = block_out_channels[0]
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels
|
||||
output_channel = (
|
||||
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
|
||||
)
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_upsample=not is_final_block,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# TODO(PVP, Nathan) placeholder for RL application to be merged shortly
|
||||
# Totally fine to add another layer with a if statement - no need for nn.Identity here
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.out_block = get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=block_out_channels[0],
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=block_out_channels[-1] // 4,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -144,12 +204,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
|
||||
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# 1. time
|
||||
if len(timestep.shape) == 0:
|
||||
timestep = timestep[None]
|
||||
|
||||
timestep_embed = self.time_proj(timestep)[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
timestep_embed = self.time_proj(timesteps)
|
||||
if self.config.use_timestep_embedding:
|
||||
timestep_embed = self.time_mlp(timestep_embed)
|
||||
else:
|
||||
timestep_embed = timestep_embed[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
|
||||
# 2. down
|
||||
down_block_res_samples = ()
|
||||
@@ -158,13 +226,18 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 3. mid
|
||||
sample = self.mid_block(sample)
|
||||
if self.mid_block:
|
||||
sample = self.mid_block(sample, timestep_embed)
|
||||
|
||||
# 4. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
res_samples = down_block_res_samples[-1:]
|
||||
down_block_res_samples = down_block_res_samples[:-1]
|
||||
sample = upsample_block(sample, res_samples)
|
||||
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
|
||||
|
||||
# 5. post-process
|
||||
if self.out_block:
|
||||
sample = self.out_block(sample, timestep_embed)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
@@ -17,6 +17,256 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
|
||||
|
||||
|
||||
class DownResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
num_layers=1,
|
||||
conv_shortcut=False,
|
||||
temb_channels=32,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
non_linearity=None,
|
||||
time_embedding_norm="default",
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_downsample = add_downsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = lambda x: F.silu(x)
|
||||
elif non_linearity == "mish":
|
||||
self.nonlinearity = nn.Mish()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
else:
|
||||
self.nonlinearity = None
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
output_states = ()
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.downsample is not None:
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UpResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
num_layers=1,
|
||||
temb_channels=32,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
non_linearity=None,
|
||||
time_embedding_norm="default",
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_upsample = add_upsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = lambda x: F.silu(x)
|
||||
elif non_linearity == "mish":
|
||||
self.nonlinearity = nn.Mish()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
else:
|
||||
self.nonlinearity = None
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
|
||||
if res_hidden_states_tuple is not None:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.upsample is not None:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ValueFunctionMidBlock1D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, embed_dim):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
|
||||
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
|
||||
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
|
||||
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
|
||||
|
||||
def forward(self, x, temb=None):
|
||||
x = self.res1(x, temb)
|
||||
x = self.down1(x)
|
||||
x = self.res2(x, temb)
|
||||
x = self.down2(x)
|
||||
return x
|
||||
|
||||
|
||||
class MidResTemporalBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
embed_dim,
|
||||
num_layers: int = 1,
|
||||
add_downsample: bool = False,
|
||||
add_upsample: bool = False,
|
||||
non_linearity=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.add_downsample = add_downsample
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = lambda x: F.silu(x)
|
||||
elif non_linearity == "mish":
|
||||
self.nonlinearity = nn.Mish()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
else:
|
||||
self.nonlinearity = None
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
if self.upsample and self.downsample:
|
||||
raise ValueError("Block cannot downsample and upsample")
|
||||
|
||||
def forward(self, hidden_states, temb):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
if self.downsample:
|
||||
self.downsample = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutConv1DBlock(nn.Module):
|
||||
def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
|
||||
super().__init__()
|
||||
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
|
||||
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
|
||||
if act_fn == "silu":
|
||||
self.final_conv1d_act = nn.SiLU()
|
||||
if act_fn == "mish":
|
||||
self.final_conv1d_act = nn.Mish()
|
||||
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = self.final_conv1d_1(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_gn(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_act(hidden_states)
|
||||
hidden_states = self.final_conv1d_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutValueFunctionBlock(nn.Module):
|
||||
def __init__(self, fc_dim, embed_dim):
|
||||
super().__init__()
|
||||
self.final_block = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
||||
nn.Mish(),
|
||||
nn.Linear(fc_dim // 2, 1),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, temb):
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
|
||||
hidden_states = torch.cat((hidden_states, temb), dim=-1)
|
||||
for layer in self.final_block:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
_kernels = {
|
||||
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||
@@ -62,7 +312,7 @@ class Upsample1d(nn.Module):
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
@@ -162,32 +412,6 @@ class ResConvBlock(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def get_down_block(down_block_type, out_channels, in_channels):
|
||||
if down_block_type == "DownBlock1D":
|
||||
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "AttnDownBlock1D":
|
||||
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "DownBlock1DNoSkip":
|
||||
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(up_block_type, in_channels, out_channels):
|
||||
if up_block_type == "UpBlock1D":
|
||||
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "AttnUpBlock1D":
|
||||
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "UpBlock1DNoSkip":
|
||||
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels):
|
||||
if mid_block_type == "UNetMidBlock1D":
|
||||
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{mid_block_type} does not exist.")
|
||||
|
||||
|
||||
class UNetMidBlock1D(nn.Module):
|
||||
def __init__(self, mid_channels, in_channels, out_channels=None):
|
||||
super().__init__()
|
||||
@@ -217,7 +441,7 @@ class UNetMidBlock1D(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = self.down(hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets):
|
||||
hidden_states = resnet(hidden_states)
|
||||
@@ -322,7 +546,7 @@ class AttnUpBlock1D(nn.Module):
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple):
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
@@ -349,7 +573,7 @@ class UpBlock1D(nn.Module):
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple):
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
@@ -374,7 +598,7 @@ class UpBlock1DNoSkip(nn.Module):
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple):
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
@@ -382,3 +606,63 @@ class UpBlock1DNoSkip(nn.Module):
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
|
||||
if down_block_type == "DownResnetBlock1D":
|
||||
return DownResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif down_block_type == "DownBlock1D":
|
||||
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "AttnDownBlock1D":
|
||||
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "DownBlock1DNoSkip":
|
||||
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
|
||||
if up_block_type == "UpResnetBlock1D":
|
||||
return UpResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
elif up_block_type == "UpBlock1D":
|
||||
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "AttnUpBlock1D":
|
||||
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "UpBlock1DNoSkip":
|
||||
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
|
||||
if mid_block_type == "MidResTemporalBlock1D":
|
||||
return MidResTemporalBlock1D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif mid_block_type == "ValueFunctionMidBlock1D":
|
||||
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
|
||||
elif mid_block_type == "UNetMidBlock1D":
|
||||
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{mid_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
|
||||
if out_block_type == "OutConv1DBlock":
|
||||
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
||||
elif out_block_type == "ValueFunction":
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim)
|
||||
return None
|
||||
|
||||
@@ -51,7 +51,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
||||
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to :
|
||||
obj:`False`): Whether to flip sin to cos for fourier time embedding.
|
||||
obj:`True`): Whether to flip sin to cos for fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
|
||||
types.
|
||||
|
||||
@@ -60,7 +60,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
|
||||
@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__)
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"FlaxSchedulerMixin": ["save_config", "from_config"],
|
||||
"FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
|
||||
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
"transformers": {
|
||||
@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
>>> from diffusers import FlaxDPMSolverMultistepScheduler
|
||||
|
||||
>>> model_id = "runwayml/stable-diffusion-v1-5"
|
||||
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_config(
|
||||
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
|
||||
... model_id,
|
||||
... subfolder="scheduler",
|
||||
... )
|
||||
@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
if not os.path.isdir(pretrained_model_name_or_path):
|
||||
config_dict = cls.get_config_dict(
|
||||
config_dict = cls.load_config(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
config_dict = cls.get_config_dict(cached_folder)
|
||||
config_dict = cls.load_config(cached_folder)
|
||||
|
||||
# 2. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
|
||||
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ logger = logging.get_logger(__name__)
|
||||
LOADABLE_CLASSES = {
|
||||
"diffusers": {
|
||||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"SchedulerMixin": ["save_config", "from_config"],
|
||||
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
|
||||
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
@@ -207,7 +207,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
if torch_device is None:
|
||||
return self
|
||||
|
||||
module_names, _ = self.extract_init_dict(dict(self.config))
|
||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
@@ -228,7 +228,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
Returns:
|
||||
`torch.device`: The torch device on which the pipeline is located.
|
||||
"""
|
||||
module_names, _ = self.extract_init_dict(dict(self.config))
|
||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
@@ -377,11 +377,11 @@ class DiffusionPipeline(ConfigMixin):
|
||||
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
|
||||
>>> # Download pipeline, but overwrite scheduler
|
||||
>>> # Use a different scheduler
|
||||
>>> from diffusers import LMSDiscreteScheduler
|
||||
|
||||
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
|
||||
>>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
>>> pipeline.scheduler = scheduler
|
||||
```
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
@@ -428,7 +428,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
if not os.path.isdir(pretrained_model_name_or_path):
|
||||
config_dict = cls.get_config_dict(
|
||||
config_dict = cls.load_config(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
@@ -474,7 +474,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
config_dict = cls.get_config_dict(cached_folder)
|
||||
config_dict = cls.load_config(cached_folder)
|
||||
|
||||
# 2. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
@@ -513,7 +513,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
|
||||
init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
|
||||
|
||||
@@ -40,7 +40,7 @@ available a colab notebook to directly try them out.
|
||||
| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* |
|
||||
| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
|
||||
| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
|
||||
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
|
||||
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* |
|
||||
|
||||
@@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
|
||||
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
from diffusers import StableDiffusionPipeline, DDIMScheduler
|
||||
|
||||
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png")
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
|
||||
|
||||
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
|
||||
# load the pipeline
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
|
||||
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
|
||||
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
|
||||
|
||||
# let's download an initial image
|
||||
|
||||
@@ -23,7 +23,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2010.02502
|
||||
|
||||
@@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"PNDMScheduler",
|
||||
"DDPMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -23,7 +23,12 @@ import flax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
from .scheduling_utils_flax import (
|
||||
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
|
||||
FlaxSchedulerMixin,
|
||||
FlaxSchedulerOutput,
|
||||
broadcast_to_shape_from_left,
|
||||
)
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
||||
@@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2010.02502
|
||||
|
||||
@@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
stable diffusion.
|
||||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@@ -22,7 +22,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||
from ..utils import BaseOutput, deprecate
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2006.11239
|
||||
|
||||
@@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"PNDMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -204,6 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# for rl-diffuser https://arxiv.org/abs/2205.09991
|
||||
elif variance_type == "fixed_small_log":
|
||||
variance = torch.log(torch.clamp(variance, min=1e-20))
|
||||
variance = torch.exp(0.5 * variance)
|
||||
elif variance_type == "fixed_large":
|
||||
variance = self.betas[t]
|
||||
elif variance_type == "fixed_large_log":
|
||||
@@ -248,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
|
||||
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
|
||||
@@ -301,7 +295,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
variance_noise = torch.randn(
|
||||
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
||||
)
|
||||
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
|
||||
if self.variance_type == "fixed_small_log":
|
||||
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
|
||||
else:
|
||||
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
|
||||
|
||||
pred_prev_sample = pred_prev_sample + variance
|
||||
|
||||
|
||||
@@ -24,7 +24,12 @@ from jax import random
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
from .scheduling_utils_flax import (
|
||||
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
|
||||
FlaxSchedulerMixin,
|
||||
FlaxSchedulerOutput,
|
||||
broadcast_to_shape_from_left,
|
||||
)
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
||||
@@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2006.11239
|
||||
|
||||
@@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
@@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
|
||||
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
|
||||
|
||||
@@ -21,6 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
@@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
@@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"PNDMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
]
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -23,7 +23,12 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
from .scheduling_utils_flax import (
|
||||
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
|
||||
FlaxSchedulerMixin,
|
||||
FlaxSchedulerOutput,
|
||||
broadcast_to_shape_from_left,
|
||||
)
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
|
||||
@@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
|
||||
|
||||
@@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@@ -19,7 +19,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@@ -52,8 +52,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
@@ -67,14 +67,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"PNDMScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -19,7 +19,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@@ -53,8 +53,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
@@ -68,14 +68,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"PNDMScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -28,8 +28,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2202.09778
|
||||
|
||||
|
||||
@@ -56,8 +56,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
|
||||
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
|
||||
|
||||
@@ -67,8 +67,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
|
||||
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
from scipy import integrate
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@@ -52,8 +52,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
@@ -67,14 +67,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"PNDMScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -20,7 +20,12 @@ import jax.numpy as jnp
|
||||
from scipy import integrate
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
from .scheduling_utils_flax import (
|
||||
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
|
||||
FlaxSchedulerMixin,
|
||||
FlaxSchedulerOutput,
|
||||
broadcast_to_shape_from_left,
|
||||
)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
@@ -49,8 +54,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
@@ -63,6 +68,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@@ -21,6 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
@@ -60,8 +61,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2202.09778
|
||||
|
||||
@@ -88,14 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
|
||||
_compatible_classes = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -23,7 +23,12 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
|
||||
from .scheduling_utils_flax import (
|
||||
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
|
||||
FlaxSchedulerMixin,
|
||||
FlaxSchedulerOutput,
|
||||
broadcast_to_shape_from_left,
|
||||
)
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
|
||||
@@ -87,8 +92,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2202.09778
|
||||
|
||||
@@ -114,6 +119,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
stable diffusion.
|
||||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@@ -77,8 +77,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf
|
||||
|
||||
|
||||
@@ -50,8 +50,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
|
||||
@@ -64,8 +64,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`): number of diffusion steps used to train the model.
|
||||
|
||||
@@ -29,8 +29,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more information, see the original paper: https://arxiv.org/abs/2011.13456
|
||||
|
||||
|
||||
@@ -11,7 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -38,6 +41,114 @@ class SchedulerOutput(BaseOutput):
|
||||
class SchedulerMixin:
|
||||
"""
|
||||
Mixin containing common functions for the schedulers.
|
||||
|
||||
Class attributes:
|
||||
- **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
|
||||
`from_config` can be used from a class different than the one used to save the config (should be overridden
|
||||
by parent class).
|
||||
"""
|
||||
|
||||
config_name = SCHEDULER_CONFIG_NAME
|
||||
_compatibles = []
|
||||
has_compatibles = True
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Dict[str, Any] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
return_unused_kwargs=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
|
||||
organization name, like `google/ddpm-celebahq-256`.
|
||||
- A path to a *directory* containing the schedluer configurations saved using
|
||||
[`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
subfolder (`str`, *optional*):
|
||||
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
||||
huggingface.co or downloaded locally), you can specify the folder name here.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
Whether kwargs that are not consumed by the Python class should be returned or not.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
|
||||
<Tip>
|
||||
|
||||
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
||||
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
|
||||
use this method in a firewalled environment.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
config, kwargs = cls.load_config(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
subfolder=subfolder,
|
||||
return_unused_kwargs=True,
|
||||
**kwargs,
|
||||
)
|
||||
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
||||
[`~SchedulerMixin.from_pretrained`] class method.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
||||
"""
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
|
||||
@property
|
||||
def compatibles(self):
|
||||
"""
|
||||
Returns all schedulers that are compatible with this scheduler
|
||||
|
||||
Returns:
|
||||
`List[SchedulerMixin]`: List of compatible schedulers
|
||||
"""
|
||||
return self._get_compatibles()
|
||||
|
||||
@classmethod
|
||||
def _get_compatibles(cls):
|
||||
compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
|
||||
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
||||
compatible_classes = [
|
||||
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
|
||||
]
|
||||
return compatible_classes
|
||||
|
||||
@@ -11,15 +11,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..utils import BaseOutput
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
|
||||
|
||||
|
||||
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
|
||||
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -39,9 +42,123 @@ class FlaxSchedulerOutput(BaseOutput):
|
||||
class FlaxSchedulerMixin:
|
||||
"""
|
||||
Mixin containing common functions for the schedulers.
|
||||
|
||||
Class attributes:
|
||||
- **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
|
||||
`from_config` can be used from a class different than the one used to save the config (should be overridden
|
||||
by parent class).
|
||||
"""
|
||||
|
||||
config_name = SCHEDULER_CONFIG_NAME
|
||||
_compatibles = []
|
||||
has_compatibles = True
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Dict[str, Any] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
return_unused_kwargs=False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Instantiate a Scheduler class from a pre-defined JSON-file.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
|
||||
organization name, like `google/ddpm-celebahq-256`.
|
||||
- A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`],
|
||||
e.g., `./my_model_directory/`.
|
||||
subfolder (`str`, *optional*):
|
||||
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
||||
huggingface.co or downloaded locally), you can specify the folder name here.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
Whether kwargs that are not consumed by the Python class should be returned or not.
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
|
||||
<Tip>
|
||||
|
||||
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
||||
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
|
||||
use this method in a firewalled environment.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
config, kwargs = cls.load_config(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
|
||||
)
|
||||
scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)
|
||||
|
||||
if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
|
||||
state = scheduler.create_state()
|
||||
|
||||
if return_unused_kwargs:
|
||||
return scheduler, state, unused_kwargs
|
||||
|
||||
return scheduler, state
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
||||
[`~FlaxSchedulerMixin.from_pretrained`] class method.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
||||
"""
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
|
||||
@property
|
||||
def compatibles(self):
|
||||
"""
|
||||
Returns all schedulers that are compatible with this scheduler
|
||||
|
||||
Returns:
|
||||
`List[SchedulerMixin]`: List of compatible schedulers
|
||||
"""
|
||||
return self._get_compatibles()
|
||||
|
||||
@classmethod
|
||||
def _get_compatibles(cls):
|
||||
compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
|
||||
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
||||
compatible_classes = [
|
||||
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
|
||||
]
|
||||
return compatible_classes
|
||||
|
||||
|
||||
def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
|
||||
|
||||
@@ -112,8 +112,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
|
||||
[`~SchedulerMixin.from_pretrained`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2111.14822
|
||||
|
||||
|
||||
@@ -72,3 +72,13 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
||||
DIFFUSERS_CACHE = default_cache_path
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
||||
|
||||
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
|
||||
"DDIMScheduler",
|
||||
"DDPMScheduler",
|
||||
"PNDMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
|
||||
@@ -18,13 +18,120 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import UNet1DModel
|
||||
from diffusers.utils import slow, torch_device
|
||||
from diffusers.utils import floats_tensor, slow, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class UnetModel1DTests(unittest.TestCase):
|
||||
class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet1DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 14, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 14, 16)
|
||||
|
||||
def test_ema_training(self):
|
||||
pass
|
||||
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
super().test_from_pretrained_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output(self):
|
||||
super().test_output()
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (32, 64, 128, 256),
|
||||
"in_channels": 14,
|
||||
"out_channels": 14,
|
||||
"time_embedding_type": "positional",
|
||||
"use_timestep_embedding": True,
|
||||
"flip_sin_to_cos": False,
|
||||
"freq_shift": 1.0,
|
||||
"out_block_type": "OutConv1DBlock",
|
||||
"mid_block_type": "MidResTemporalBlock1D",
|
||||
"down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
|
||||
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
|
||||
"act_fn": "mish",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
|
||||
)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output_pretrained(self):
|
||||
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_features = model.in_channels
|
||||
seq_len = 16
|
||||
noise = torch.randn((1, seq_len, num_features)).permute(
|
||||
0, 2, 1
|
||||
) # match original, we can update values and remove
|
||||
time_step = torch.full((num_features,), 0)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step).sample.permute(0, 2, 1)
|
||||
|
||||
output_slice = output[0, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
def test_forward_with_norm_groups(self):
|
||||
# Not implemented yet for this UNet
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_unet_1d_maestro(self):
|
||||
model_id = "harmonai/maestro-150k"
|
||||
@@ -43,3 +150,127 @@ class UnetModel1DTests(unittest.TestCase):
|
||||
|
||||
assert (output_sum - 224.0896).abs() < 4e-2
|
||||
assert (output_max - 0.0607).abs() < 4e-4
|
||||
|
||||
|
||||
class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet1DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 14, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 14, 1)
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
super().test_from_pretrained_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output(self):
|
||||
# UNetRL is a value-function is different output shape
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_ema_training(self):
|
||||
pass
|
||||
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 14,
|
||||
"out_channels": 14,
|
||||
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
|
||||
"up_block_types": [],
|
||||
"out_block_type": "ValueFunction",
|
||||
"mid_block_type": "ValueFunctionMidBlock1D",
|
||||
"block_out_channels": [32, 64, 128, 256],
|
||||
"layers_per_block": 1,
|
||||
"downsample_each_block": True,
|
||||
"use_timestep_embedding": True,
|
||||
"freq_shift": 1.0,
|
||||
"flip_sin_to_cos": False,
|
||||
"time_embedding_type": "positional",
|
||||
"act_fn": "mish",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_pretrained_hub(self):
|
||||
value_function, vf_loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
|
||||
)
|
||||
self.assertIsNotNone(value_function)
|
||||
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
|
||||
|
||||
value_function.to(torch_device)
|
||||
image = value_function(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_output_pretrained(self):
|
||||
value_function, vf_loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_features = value_function.in_channels
|
||||
seq_len = 14
|
||||
noise = torch.randn((1, seq_len, num_features)).permute(
|
||||
0, 2, 1
|
||||
) # match original, we can update values and remove
|
||||
time_step = torch.full((num_features,), 0)
|
||||
|
||||
with torch.no_grad():
|
||||
output = value_function(noise, time_step).sample
|
||||
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([165.25] * seq_len)
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
|
||||
|
||||
def test_forward_with_norm_groups(self):
|
||||
# Not implemented yet for this UNet
|
||||
pass
|
||||
|
||||
@@ -44,6 +44,10 @@ class PipelineFastTests(unittest.TestCase):
|
||||
sample_rate=16_000,
|
||||
in_channels=2,
|
||||
out_channels=2,
|
||||
flip_sin_to_cos=True,
|
||||
use_timestep_embedding=False,
|
||||
time_embedding_type="fourier",
|
||||
mid_block_type="UNetMidBlock1D",
|
||||
down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"],
|
||||
up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"],
|
||||
)
|
||||
|
||||
@@ -75,7 +75,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
|
||||
model_id = "google/ddpm-ema-bedroom-256"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDIMScheduler.from_config(model_id)
|
||||
scheduler = DDIMScheduler.from_pretrained(model_id)
|
||||
|
||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
|
||||
@@ -106,7 +106,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = DDPMScheduler.from_config(model_id)
|
||||
scheduler = DDPMScheduler.from_pretrained(model_id)
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
|
||||
@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
model_id = "google/ddpm-ema-celebahq-256"
|
||||
unet = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = RePaintScheduler.from_config(model_id)
|
||||
scheduler = RePaintScheduler.from_pretrained(model_id)
|
||||
|
||||
repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device)
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
|
||||
model_id = "google/ncsnpp-church-256"
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
|
||||
scheduler = ScoreSdeVeScheduler.from_config(model_id)
|
||||
scheduler = ScoreSdeVeScheduler.from_pretrained(model_id)
|
||||
|
||||
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
|
||||
sde_ve.to(torch_device)
|
||||
|
||||
@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
init_image = init_image.resize((512, 512))
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
|
||||
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = CycleDiffusionPipeline.from_pretrained(
|
||||
model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16"
|
||||
)
|
||||
@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
init_image = init_image.resize((512, 512))
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
|
||||
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)
|
||||
|
||||
pipe.to(torch_device)
|
||||
|
||||
@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_inference_ddim(self):
|
||||
ddim_scheduler = DDIMScheduler.from_config(
|
||||
ddim_scheduler = DDIMScheduler.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
|
||||
)
|
||||
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_inference_k_lms(self):
|
||||
lms_scheduler = LMSDiscreteScheduler.from_config(
|
||||
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
|
||||
)
|
||||
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
|
||||
@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
lms_scheduler = LMSDiscreteScheduler.from_config(
|
||||
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
|
||||
)
|
||||
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
|
||||
@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
)
|
||||
lms_scheduler = LMSDiscreteScheduler.from_config(
|
||||
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx"
|
||||
)
|
||||
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
|
||||
|
||||
@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_fast_ddim(self):
|
||||
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
|
||||
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
|
||||
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
model_id = "CompVis/stable-diffusion-v1-1"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
|
||||
scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe.scheduler = scheduler
|
||||
|
||||
prompt = "a photograph of an astronaut riding a horse"
|
||||
|
||||
@@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
|
||||
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
model_id,
|
||||
scheduler=lms,
|
||||
@@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
ddim = DDIMScheduler.from_config(model_id, subfolder="scheduler")
|
||||
ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
model_id,
|
||||
scheduler=ddim,
|
||||
@@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
init_image = init_image.resize((768, 512))
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
|
||||
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
@@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
model_id = "runwayml/stable-diffusion-inpainting"
|
||||
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
|
||||
pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
@@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
model_id = "runwayml/stable-diffusion-inpainting"
|
||||
pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
|
||||
pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
model_id,
|
||||
safety_checker=None,
|
||||
|
||||
@@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
|
||||
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
model_id,
|
||||
scheduler=lms,
|
||||
|
||||
@@ -13,12 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin):
|
||||
class ConfigTester(unittest.TestCase):
|
||||
def test_load_not_from_mixin(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ConfigMixin.from_config("dummy_path")
|
||||
ConfigMixin.load_config("dummy_path")
|
||||
|
||||
def test_register_to_config(self):
|
||||
obj = SampleObject()
|
||||
@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
obj.save_config(tmpdirname)
|
||||
new_obj = SampleObject.from_config(tmpdirname)
|
||||
new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname))
|
||||
new_config = new_obj.config
|
||||
|
||||
# unfreeze configs
|
||||
@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase):
|
||||
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
|
||||
assert config == new_config
|
||||
|
||||
def test_save_load_from_different_config(self):
|
||||
obj = SampleObject()
|
||||
|
||||
# mock add obj class to `diffusers`
|
||||
setattr(diffusers, "SampleObject", SampleObject)
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
obj.save_config(tmpdirname)
|
||||
with CaptureLogger(logger) as cap_logger_1:
|
||||
new_obj_1 = SampleObject2.from_config(tmpdirname)
|
||||
|
||||
# now save a config parameter that is not expected
|
||||
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
|
||||
data = json.load(f)
|
||||
data["unexpected"] = True
|
||||
|
||||
with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_2:
|
||||
new_obj_2 = SampleObject.from_config(tmpdirname)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_3:
|
||||
new_obj_3 = SampleObject2.from_config(tmpdirname)
|
||||
|
||||
assert new_obj_1.__class__ == SampleObject2
|
||||
assert new_obj_2.__class__ == SampleObject
|
||||
assert new_obj_3.__class__ == SampleObject2
|
||||
|
||||
assert cap_logger_1.out == ""
|
||||
assert (
|
||||
cap_logger_2.out
|
||||
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
|
||||
" be ignored. Please verify your config.json configuration file.\n"
|
||||
)
|
||||
assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out
|
||||
|
||||
def test_save_load_compatible_schedulers(self):
|
||||
SampleObject2._compatible_classes = ["SampleObject"]
|
||||
SampleObject._compatible_classes = ["SampleObject2"]
|
||||
|
||||
obj = SampleObject()
|
||||
|
||||
# mock add obj class to `diffusers`
|
||||
setattr(diffusers, "SampleObject", SampleObject)
|
||||
setattr(diffusers, "SampleObject2", SampleObject2)
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
obj.save_config(tmpdirname)
|
||||
|
||||
# now save a config parameter that is expected by another class, but not origin class
|
||||
with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
|
||||
data = json.load(f)
|
||||
data["f"] = [0, 0]
|
||||
data["unexpected"] = True
|
||||
|
||||
with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
new_obj = SampleObject.from_config(tmpdirname)
|
||||
|
||||
assert new_obj.__class__ == SampleObject
|
||||
|
||||
assert (
|
||||
cap_logger.out
|
||||
== "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
|
||||
" be ignored. Please verify your config.json configuration file.\n"
|
||||
)
|
||||
|
||||
def test_save_load_from_different_config_comp_schedulers(self):
|
||||
SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"]
|
||||
SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"]
|
||||
SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"]
|
||||
|
||||
obj = SampleObject()
|
||||
|
||||
# mock add obj class to `diffusers`
|
||||
setattr(diffusers, "SampleObject", SampleObject)
|
||||
setattr(diffusers, "SampleObject2", SampleObject2)
|
||||
setattr(diffusers, "SampleObject3", SampleObject3)
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
logger.setLevel(diffusers.logging.INFO)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
obj.save_config(tmpdirname)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_1:
|
||||
new_obj_1 = SampleObject.from_config(tmpdirname)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_2:
|
||||
new_obj_2 = SampleObject2.from_config(tmpdirname)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_3:
|
||||
new_obj_3 = SampleObject3.from_config(tmpdirname)
|
||||
|
||||
assert new_obj_1.__class__ == SampleObject
|
||||
assert new_obj_2.__class__ == SampleObject2
|
||||
assert new_obj_3.__class__ == SampleObject3
|
||||
|
||||
assert cap_logger_1.out == ""
|
||||
assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
|
||||
assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
|
||||
|
||||
def test_load_ddim_from_pndm(self):
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
ddim = DDIMScheduler.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
|
||||
)
|
||||
|
||||
assert ddim.__class__ == DDIMScheduler
|
||||
# no warning should be thrown
|
||||
@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase):
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
euler = EulerDiscreteScheduler.from_config(
|
||||
euler = EulerDiscreteScheduler.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
|
||||
)
|
||||
|
||||
@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase):
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
euler = EulerAncestralDiscreteScheduler.from_config(
|
||||
euler = EulerAncestralDiscreteScheduler.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
|
||||
)
|
||||
|
||||
@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase):
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
pndm = PNDMScheduler.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
|
||||
)
|
||||
|
||||
assert pndm.__class__ == PNDMScheduler
|
||||
# no warning should be thrown
|
||||
@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase):
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
ddpm = DDPMScheduler.from_config(
|
||||
ddpm = DDPMScheduler.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch",
|
||||
subfolder="scheduler",
|
||||
predict_epsilon=False,
|
||||
@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_2:
|
||||
ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88)
|
||||
ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
|
||||
|
||||
assert ddpm.__class__ == DDPMScheduler
|
||||
assert ddpm.config.predict_epsilon is False
|
||||
@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase):
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
dpm = DPMSolverMultistepScheduler.from_config(
|
||||
dpm = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
|
||||
)
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ class ModelTesterMixin:
|
||||
expected_arg_names = ["sample", "timestep"]
|
||||
self.assertListEqual(arg_names[:2], expected_arg_names)
|
||||
|
||||
def test_model_from_config(self):
|
||||
def test_model_from_pretrained(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
@@ -140,8 +140,8 @@ class ModelTesterMixin:
|
||||
# test if the model can be loaded from the config
|
||||
# and has all the expected shape
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_config(tmpdirname)
|
||||
new_model = self.model_class.from_config(tmpdirname)
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname)
|
||||
new_model.to(torch_device)
|
||||
new_model.eval()
|
||||
|
||||
|
||||
@@ -29,6 +29,10 @@ from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMPipeline,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
|
||||
assert image_img2img.shape == (1, 32, 32, 3)
|
||||
assert image_text2img.shape == (1, 128, 128, 3)
|
||||
|
||||
def test_set_scheduler(self):
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
sd = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
|
||||
sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
|
||||
assert isinstance(sd.scheduler, DDIMScheduler)
|
||||
sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config)
|
||||
assert isinstance(sd.scheduler, DDPMScheduler)
|
||||
sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
|
||||
assert isinstance(sd.scheduler, PNDMScheduler)
|
||||
sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config)
|
||||
assert isinstance(sd.scheduler, LMSDiscreteScheduler)
|
||||
sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config)
|
||||
assert isinstance(sd.scheduler, EulerDiscreteScheduler)
|
||||
sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config)
|
||||
assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler)
|
||||
sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
|
||||
assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
|
||||
|
||||
def test_set_scheduler_consistency(self):
|
||||
unet = self.dummy_cond_unet
|
||||
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
sd = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=pndm,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
|
||||
pndm_config = sd.scheduler.config
|
||||
sd.scheduler = DDPMScheduler.from_config(pndm_config)
|
||||
sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
|
||||
pndm_config_2 = sd.scheduler.config
|
||||
pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config}
|
||||
|
||||
assert dict(pndm_config) == dict(pndm_config_2)
|
||||
|
||||
sd = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=ddim,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
|
||||
ddim_config = sd.scheduler.config
|
||||
sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config)
|
||||
sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
|
||||
ddim_config_2 = sd.scheduler.config
|
||||
ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config}
|
||||
|
||||
assert dict(ddim_config) == dict(ddim_config_2)
|
||||
|
||||
|
||||
@slow
|
||||
class PipelineSlowTests(unittest.TestCase):
|
||||
@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
def test_output_format(self):
|
||||
model_path = "google/ddpm-cifar10-32"
|
||||
|
||||
scheduler = DDIMScheduler.from_config(model_path)
|
||||
scheduler = DDIMScheduler.from_pretrained(model_path)
|
||||
pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
@@ -21,6 +23,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
@@ -32,13 +35,180 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
ScoreSdeVeScheduler,
|
||||
VQDiffusionScheduler,
|
||||
logging,
|
||||
)
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from diffusers.utils import deprecate, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class SchedulerObject(SchedulerMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
a=2,
|
||||
b=5,
|
||||
c=(2, 5),
|
||||
d="for diffusion",
|
||||
e=[1, 3],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class SchedulerObject2(SchedulerMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
a=2,
|
||||
b=5,
|
||||
c=(2, 5),
|
||||
d="for diffusion",
|
||||
f=[1, 3],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class SchedulerObject3(SchedulerMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
a=2,
|
||||
b=5,
|
||||
c=(2, 5),
|
||||
d="for diffusion",
|
||||
e=[1, 3],
|
||||
f=[1, 3],
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class SchedulerBaseTests(unittest.TestCase):
|
||||
def test_save_load_from_different_config(self):
|
||||
obj = SchedulerObject()
|
||||
|
||||
# mock add obj class to `diffusers`
|
||||
setattr(diffusers, "SchedulerObject", SchedulerObject)
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
obj.save_config(tmpdirname)
|
||||
with CaptureLogger(logger) as cap_logger_1:
|
||||
config = SchedulerObject2.load_config(tmpdirname)
|
||||
new_obj_1 = SchedulerObject2.from_config(config)
|
||||
|
||||
# now save a config parameter that is not expected
|
||||
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
|
||||
data = json.load(f)
|
||||
data["unexpected"] = True
|
||||
|
||||
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_2:
|
||||
config = SchedulerObject.load_config(tmpdirname)
|
||||
new_obj_2 = SchedulerObject.from_config(config)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_3:
|
||||
config = SchedulerObject2.load_config(tmpdirname)
|
||||
new_obj_3 = SchedulerObject2.from_config(config)
|
||||
|
||||
assert new_obj_1.__class__ == SchedulerObject2
|
||||
assert new_obj_2.__class__ == SchedulerObject
|
||||
assert new_obj_3.__class__ == SchedulerObject2
|
||||
|
||||
assert cap_logger_1.out == ""
|
||||
assert (
|
||||
cap_logger_2.out
|
||||
== "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
|
||||
" will"
|
||||
" be ignored. Please verify your config.json configuration file.\n"
|
||||
)
|
||||
assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out
|
||||
|
||||
def test_save_load_compatible_schedulers(self):
|
||||
SchedulerObject2._compatibles = ["SchedulerObject"]
|
||||
SchedulerObject._compatibles = ["SchedulerObject2"]
|
||||
|
||||
obj = SchedulerObject()
|
||||
|
||||
# mock add obj class to `diffusers`
|
||||
setattr(diffusers, "SchedulerObject", SchedulerObject)
|
||||
setattr(diffusers, "SchedulerObject2", SchedulerObject2)
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
obj.save_config(tmpdirname)
|
||||
|
||||
# now save a config parameter that is expected by another class, but not origin class
|
||||
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
|
||||
data = json.load(f)
|
||||
data["f"] = [0, 0]
|
||||
data["unexpected"] = True
|
||||
|
||||
with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
config = SchedulerObject.load_config(tmpdirname)
|
||||
new_obj = SchedulerObject.from_config(config)
|
||||
|
||||
assert new_obj.__class__ == SchedulerObject
|
||||
|
||||
assert (
|
||||
cap_logger.out
|
||||
== "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
|
||||
" will"
|
||||
" be ignored. Please verify your config.json configuration file.\n"
|
||||
)
|
||||
|
||||
def test_save_load_from_different_config_comp_schedulers(self):
|
||||
SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"]
|
||||
SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"]
|
||||
SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"]
|
||||
|
||||
obj = SchedulerObject()
|
||||
|
||||
# mock add obj class to `diffusers`
|
||||
setattr(diffusers, "SchedulerObject", SchedulerObject)
|
||||
setattr(diffusers, "SchedulerObject2", SchedulerObject2)
|
||||
setattr(diffusers, "SchedulerObject3", SchedulerObject3)
|
||||
logger = logging.get_logger("diffusers.configuration_utils")
|
||||
logger.setLevel(diffusers.logging.INFO)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
obj.save_config(tmpdirname)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_1:
|
||||
config = SchedulerObject.load_config(tmpdirname)
|
||||
new_obj_1 = SchedulerObject.from_config(config)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_2:
|
||||
config = SchedulerObject2.load_config(tmpdirname)
|
||||
new_obj_2 = SchedulerObject2.from_config(config)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger_3:
|
||||
config = SchedulerObject3.load_config(tmpdirname)
|
||||
new_obj_3 = SchedulerObject3.from_config(config)
|
||||
|
||||
assert new_obj_1.__class__ == SchedulerObject
|
||||
assert new_obj_2.__class__ == SchedulerObject2
|
||||
assert new_obj_3.__class__ == SchedulerObject3
|
||||
|
||||
assert cap_logger_1.out == ""
|
||||
assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
|
||||
assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
|
||||
|
||||
|
||||
class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler_classes = ()
|
||||
forward_default_kwargs = ()
|
||||
@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_compatibles(self):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
assert all(c is not None for c in scheduler.compatibles)
|
||||
|
||||
for comp_scheduler_cls in scheduler.compatibles:
|
||||
comp_scheduler = comp_scheduler_cls.from_config(scheduler.config)
|
||||
assert comp_scheduler is not None
|
||||
|
||||
new_scheduler = scheduler_class.from_config(comp_scheduler.config)
|
||||
|
||||
new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config}
|
||||
scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config}
|
||||
|
||||
# make sure that configs are essentially identical
|
||||
assert new_scheduler_config == dict(scheduler.config)
|
||||
|
||||
# make sure that only differences are for configs that are not in init
|
||||
init_keys = inspect.signature(scheduler_class.__init__).parameters.keys()
|
||||
assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set()
|
||||
|
||||
def test_from_pretrained(self):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_pretrained(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
assert scheduler.config == new_scheduler.config
|
||||
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
|
||||
@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
output = scheduler.step_pred(
|
||||
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
|
||||
@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
output = scheduler.step_pred(
|
||||
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
|
||||
@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_pretrained(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
state = scheduler.set_timesteps(state, num_inference_steps)
|
||||
@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
|
||||
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
|
||||
# copy over dummy past residuals
|
||||
new_state = new_state.replace(ets=dummy_past_residuals[:])
|
||||
@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user