1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge branch 'main' of https://github.com/huggingface/diffusers into stable_diff_opti

This commit is contained in:
Nouamane Tazi
2022-09-19 08:00:09 +00:00
119 changed files with 7033 additions and 918 deletions

27
.github/workflows/stale.yml vendored Normal file
View File

@@ -0,0 +1,27 @@
name: Stale Bot
on:
schedule:
- cron: "0 15 * * *"
jobs:
close_stale_issues:
name: Close Stale Issues
if: github.repository == 'huggingface/diffusers'
runs-on: ubuntu-latest
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- uses: actions/checkout@v2
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.7
- name: Install requirements
run: |
pip install PyGithub
- name: Close stale issues
run: |
python utils/stale.py

14
.github/workflows/typos.yml vendored Normal file
View File

@@ -0,0 +1,14 @@
name: Check typos
on:
workflow_dispatch:
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: typos-action
uses: crate-ci/typos@v1.12.4

View File

@@ -21,7 +21,7 @@ as a modular toolbox for inference and training of diffusion models.
More precisely, 🤗 Diffusers offers:
- State-of-the-art diffusion pipelines that can be run in inference with just a couple of lines of code (see [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)). Check [this overview](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/README.md#pipelines-summary) to see all supported pipelines and their corresponding official papers.
- Various noise schedulers that can be used interchangeably for the prefered speed vs. quality trade-off in inference (see [src/diffusers/schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)).
- Various noise schedulers that can be used interchangeably for the preferred speed vs. quality trade-off in inference (see [src/diffusers/schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)).
- Multiple types of models, such as UNet, can be used as building blocks in an end-to-end diffusion system (see [src/diffusers/models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)).
- Training examples to show how to train the most popular diffusion model tasks (see [examples](https://github.com/huggingface/diffusers/tree/main/examples), *e.g.* [unconditional-image-generation](https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation)).
@@ -30,7 +30,7 @@ More precisely, 🤗 Diffusers offers:
**With `pip`**
```bash
pip install --upgrade diffusers # should install diffusers 0.2.4
pip install --upgrade diffusers
```
**With `conda`**
@@ -39,6 +39,10 @@ pip install --upgrade diffusers # should install diffusers 0.2.4
conda install -c conda-forge diffusers
```
**Apple Silicon (M1/M2) support**
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
## Contributing
We ❤️ contributions from the open-source community!
@@ -191,7 +195,7 @@ with autocast("cuda"):
images[0].save("fantasy_landscape.png")
```
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
### In-painting using Stable Diffusion
@@ -254,42 +258,49 @@ If you want to run the code yourself 💻, you can try out:
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
```python
# !pip install diffusers transformers
from torch import autocast
from diffusers import DiffusionPipeline
device = "cuda"
model_id = "CompVis/ldm-text2im-large-256"
# load model and scheduler
ldm = DiffusionPipeline.from_pretrained(model_id)
ldm = ldm.to(device)
# run pipeline in inference (sample random noise and denoise)
prompt = "A painting of a squirrel eating a burger"
images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images
with autocast(device):
image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0]
# save images
for idx, image in enumerate(images):
image.save(f"squirrel-{idx}.png")
# save image
image.save("squirrel.png")
```
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
```python
# !pip install diffusers
from torch import autocast
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
model_id = "google/ddpm-celebahq-256"
device = "cuda"
# load model and scheduler
ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
ddpm.to(device)
# run pipeline in inference (sample random noise and denoise)
image = ddpm().images
with autocast("cuda"):
image = ddpm().images[0]
# save image
image[0].save("ddpm_generated_image.png")
image.save("ddpm_generated_image.png")
```
- [Unconditional Latent Diffusion](https://huggingface.co/CompVis/ldm-celebahq-256)
- [Unconditional Diffusion with continous scheduler](https://huggingface.co/google/ncsnpp-ffhq-1024)
- [Unconditional Diffusion with continuous scheduler](https://huggingface.co/google/ncsnpp-ffhq-1024)
**Other Notebooks**:
* [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
* [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
* [tweak images via repeated Stable Diffusion seeds](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
### Web Demos
@@ -335,8 +346,8 @@ The class provides functionality to compute previous image according to alpha, b
## Philosophy
- Readability and clarity is prefered over highly optimized code. A strong importance is put on providing readable, intuitive and elementary code design. *E.g.*, the provided [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) are separated from the provided [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and provide well-commented code that can be read alongside the original paper.
- Diffusers is **modality independent** and focuses on providing pretrained models and tools to build systems that generate **continous outputs**, *e.g.* vision and audio.
- Readability and clarity is preferred over highly optimized code. A strong importance is put on providing readable, intuitive and elementary code design. *E.g.*, the provided [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) are separated from the provided [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and provide well-commented code that can be read alongside the original paper.
- Diffusers is **modality independent** and focuses on providing pretrained models and tools to build systems that generate **continuous outputs**, *e.g.* vision and audio.
- Diffusion models and schedulers are provided as concise, elementary building blocks. In contrast, diffusion pipelines are a collection of end-to-end diffusion systems that can be used out-of-the-box, should stay as close as possible to their original implementation and can include components of another library, such as text-encoders. Examples for diffusion pipelines are [Glide](https://github.com/openai/glide-text2im) and [Latent Diffusion](https://github.com/CompVis/latent-diffusion).
## In the works

12
_typos.toml Normal file
View File

@@ -0,0 +1,12 @@
# Files for typos
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
[default.extend-identifiers]
[default.extend-words]
NIN_="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
nd="np" # nd may be np (numpy)
[files]
extend-exclude = ["_typos.toml"]

View File

@@ -35,8 +35,6 @@
title: "Open Vino"
- local: optimization/mps
title: "MPS"
- local: optimization/other
title: "Other"
title: "Optimization/Special Hardware"
- sections:
- local: training/overview

View File

@@ -10,19 +10,14 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Models
# Configuration
Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are
passed to the respective `__init__` methods in a JSON-configuration file.
## API
TODO(PVP) - add example and better info here
Models should provide the `def forward` function and initialization of the model.
All saving, loading, and utilities should be in the base ['ModelMixin'] class.
## Examples
- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
- TODO: mention VAE / SDE score estimation
## ConfigMixin
[[autodoc]] ConfigMixin
- from_config
- save_config

View File

@@ -10,19 +10,30 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Models
# Pipelines
Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
The [`DiffusionPipeline`] is the easiest way to load any pretrained diffusion pipeline from the [Hub](https://huggingface.co/models?library=diffusers) and to use it in inference.
## API
<Tip>
One should not use the Diffusion Pipeline class for training or fine-tuning a diffusion model. Individual
components of diffusion pipelines are usually trained individually, so we suggest to directly work
with [`UNetModel`] and [`UNetConditionModel`].
Models should provide the `def forward` function and initialization of the model.
All saving, loading, and utilities should be in the base ['ModelMixin'] class.
</Tip>
## Examples
Any diffusion pipeline that is loaded with [`~DiffusionPipeline.from_pretrained`] will automatically
detect the pipeline type, *e.g.* [`StableDiffusionPipeline`] and consequently load each component of the
pipeline and pass them into the `__init__` function of the pipeline, *e.g.* [`~StableDiffusionPipeline.__init__`].
- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
- TODO: mention VAE / SDE score estimation
Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrained`].
## DiffusionPipeline
[[autodoc]] DiffusionPipeline
- from_pretrained
- save_pretrained
## ImagePipelineOutput
By default diffusion pipelines return an object of class
[[autodoc]] pipeline_utils.ImagePipelineOutput

View File

@@ -16,13 +16,32 @@ Diffusers contains pretrained models for popular algorithms and modules for crea
The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
## API
## ModelMixin
[[autodoc]] ModelMixin
Models should provide the `def forward` function and initialization of the model.
All saving, loading, and utilities should be in the base ['ModelMixin'] class.
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput
## Examples
## UNet2DModel
[[autodoc]] UNet2DModel
- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
- TODO: mention VAE / SDE score estimation
## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
## UNet2DConditionModel
[[autodoc]] UNet2DConditionModel
## DecoderOutput
[[autodoc]] models.vae.DecoderOutput
## VQEncoderOutput
[[autodoc]] models.vae.VQEncoderOutput
## VQModel
[[autodoc]] VQModel
## AutoencoderKLOutput
[[autodoc]] models.vae.AutoencoderKLOutput
## AutoencoderKL
[[autodoc]] AutoencoderKL

View File

@@ -10,19 +10,46 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Models
# BaseOutputs
Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models.
The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
All models have outputs that are instances of subclasses of [`~utils.BaseOutput`]. Those are
data structures containing all the information returned by the model, but that can also be used as tuples or
dictionaries.
## API
Let's see how this looks in an example:
Models should provide the `def forward` function and initialization of the model.
All saving, loading, and utilities should be in the base ['ModelMixin'] class.
```python
from diffusers import DDIMPipeline
## Examples
pipeline = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
outputs = pipeline()
```
- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
- TODO: mention VAE / SDE score estimation
The `outputs` object is a [`~pipeline_utils.ImagePipelineOutput`], as we can see in the
documentation of that class below, it means it has an image attribute.
You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you will get `None`:
```python
outputs.images
```
or via keyword lookup
```python
outputs["images"]
```
When considering our `outputs` object as tuple, it only considers the attributes that don't have `None` values.
Here for instance, we could retrieve images via indexing:
```python
outputs[:1]
```
which will return the tuple `(outputs.images)` for instance.
## BaseOutput
[[autodoc]] utils.BaseOutput
- to_tuple

View File

@@ -17,7 +17,6 @@ The original codebase of this paper can be found [here](https://github.com/ermon
| [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py) | *Unconditional Image Generation* | - |
## API
[[autodoc]] pipelines.ddim.pipeline_ddim.DDIMPipeline
## DDIMPipeline
[[autodoc]] DDIMPipeline
- __call__

View File

@@ -19,7 +19,6 @@ The original codebase of this paper can be found [here](https://github.com/hojon
| [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm/pipeline_ddpm.py) | *Unconditional Image Generation* | - |
## API
[[autodoc]] pipelines.ddpm.pipeline_ddpm.DDPMPipeline
# DDPMPipeline
[[autodoc]] DDPMPipeline
- __call__

View File

@@ -25,7 +25,6 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
## Examples:
## API
## LDMTextToImagePipeline
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
- __call__

View File

@@ -24,7 +24,6 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
## Examples:
## API
[[autodoc]] pipelines.latent_diffusion_uncond.pipeline_latent_diffusion_uncond.LDMPipeline
## LDMPipeline
[[autodoc]] LDMPipeline
- __call__

View File

@@ -51,8 +51,8 @@ available a colab notebook to directly try them out.
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochatic_karras_ve](./stochatic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
@@ -143,7 +143,7 @@ with autocast("cuda"):
images[0].save("fantasy_landscape.png")
```
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
### Tweak prompts reusing seeds and latents
@@ -187,4 +187,4 @@ with autocast("cuda"):
images[0].save("cat_on_bench.png")
```
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb)
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)

View File

@@ -17,8 +17,7 @@ The original codebase can be found [here](https://github.com/luping-liu/PNDM).
| [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm/pipeline_pndm.py) | *Unconditional Image Generation* | - |
## API
## PNDMPipeline
[[autodoc]] pipelines.pndm.pipeline_pndm.PNDMPipeline
- __call__

View File

@@ -18,8 +18,7 @@ This pipeline implements the Variance Expanding (VE) variant of the method.
|---|---|:---:|
| [pipeline_score_sde_ve.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py) | *Unconditional Image Generation* | - |
## API
[[autodoc]] pipelines.score_sde_ve.pipeline_score_sde_ve.ScoreSdeVePipeline
## ScoreSdeVePipeline
[[autodoc]] ScoreSdeVePipeline
- __call__

View File

@@ -1,40 +1,39 @@
# Stable diffusion pipelines
## Overview
Stable Diffusion is a text-to-image _latent diffusion_ model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) dataset. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and can run on consumer GPUs.
Latent diffusion is the research on top of which Stable Diffusion was built. It was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer. You can learn more details about it in the [specific pipeline for latent diffusion](pipelines/latent_diffusion) that is part of 🤗 Diffusers.
For more details about how Stable Diffusion works and how it differs from the base latent diffusion model, please refer to the official [launch announcement post](https://stability.ai/blog/stable-diffusion-announcement) and [this section of our own blog post](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work).
## Tips
*Tips*:
- To tweak your prompts on a specific result you liked, you can generate your own latents, as demonstrated in the following notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb)
- TODO: some interesting Tips
## Available pipelines
*Overview*:
| Pipeline | Tasks | Colab | Demo
|---|---|:---:|:---:|
| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) | [🤗 Stable Diffusion](https://huggingface.co/spaces/stabilityai/stable-diffusion)
| [pipeline_stable_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) | [🤗 Diffuse the Rest](https://huggingface.co/spaces/huggingface/diffuse-the-rest)
| [pipeline_stable_diffusion_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | **Experimental** *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb) | Coming soon
|---|---|:---:|:---:|
| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) | [🤗 Stable Diffusion](https://huggingface.co/spaces/stabilityai/stable-diffusion)
| [pipeline_stable_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [🤗 Diffuse the Rest](https://huggingface.co/spaces/huggingface/diffuse-the-rest)
| [pipeline_stable_diffusion_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | **Experimental** *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | Coming soon
## API
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
## StableDiffusionPipeline
[[autodoc]] StableDiffusionPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- __call__
- enable_attention_slicing
- disable_attention_slicing
## StableDiffusionImg2ImgPipeline
[[autodoc]] StableDiffusionImg2ImgPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- __call__
- enable_attention_slicing
- disable_attention_slicing
## StableDiffusionInpaintPipeline
[[autodoc]] StableDiffusionInpaintPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- __call__
- enable_attention_slicing
- disable_attention_slicing

View File

@@ -18,7 +18,6 @@ This pipeline implements the Stochastic sampling tailored to the Variance-Expand
| [pipeline_stochastic_karras_ve.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py) | *Unconditional Image Generation* | - |
## API
[[autodoc]] pipelines.stochastic_karras_ve.pipeline_stochastic_karras_ve.KarrasVePipeline
## KarrasVePipeline
[[autodoc]] KarrasVePipeline
- __call__

View File

@@ -14,7 +14,8 @@ specific language governing permissions and limitations under the License.
Diffusers contains multiple pre-built schedule functions for the diffusion process.
## What is a schduler?
## What is a scheduler?
The schedule functions, denoted *Schedulers* in the library take in the output of a trained model, a sample which the diffusion process is iterating on, and a timestep to return a denoised sample.
- Schedulers define the methodology for iteratively adding noise to an image or for updating a sample based on model outputs.
@@ -23,73 +24,77 @@ The schedule functions, denoted *Schedulers* in the library take in the output o
- Schedulers are often defined by a *noise schedule* and an *update rule* to solve the differential equation solution.
### Discrete versus continuous schedulers
All schedulers take in a timestep to predict the updated version of the sample being diffused.
The timesteps dictate where in the diffusion process the step is, where data is generated by iterating forward in time and inference is executed by propagating backwards through timesteps.
Different algorithms use timesteps that both discrete (accepting `int` inputs), such as the [`DDPMScheduler`] or [`PNDMScheduler`], and continuous (accepting 'float` inputs), such as the score-based schedulers [`ScoreSdeVeScheduler`] or [`ScoreSdeVpScheduler`].
Different algorithms use timesteps that both discrete (accepting `int` inputs), such as the [`DDPMScheduler`] or [`PNDMScheduler`], and continuous (accepting `float` inputs), such as the score-based schedulers [`ScoreSdeVeScheduler`] or [`ScoreSdeVpScheduler`].
## Designing Re-usable schedulers
The core design principle between the schedule functions is to be model, system, and framework independent.
This allows for rapid experimentation and cleaner abstractions in the code, where the model prediction is separated from the sample update.
To this end, the design of schedulers is such that:
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Numpy support currently exists).
## API
The core API for any new scheduler must follow a limited structure.
- Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively.
- Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task.
- Schedulers should be framework-agonstic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch
with a `set_format(...)` method.
### Core
The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers.
#### SchedulerMixin
### SchedulerMixin
[[autodoc]] SchedulerMixin
#### SchedulerOutput
The class [`SchedulerOutput`] contains the ouputs from any schedulers `step(...)` call.
### SchedulerOutput
The class [`SchedulerOutput`] contains the outputs from any schedulers `step(...)` call.
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
### Existing Schedulers
### Implemented Schedulers
#### Denoising diffusion implicit models (DDIM)
Original paper can be found here.
[[autodoc]] schedulers.scheduling_ddim.DDIMScheduler
[[autodoc]] DDIMScheduler
#### Denoising diffusion probabilistic models (DDPM)
Original paper can be found [here](https://arxiv.org/abs/2010.02502).
[[autodoc]] schedulers.scheduling_ddpm.DDPMScheduler
[[autodoc]] DDPMScheduler
#### Varience exploding, stochastic sampling from Karras et. al
#### Variance exploding, stochastic sampling from Karras et. al
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeScheduler
[[autodoc]] KarrasVeScheduler
#### Linear multistep scheduler for discrete beta schedules
Original implementation can be found [here](https://arxiv.org/abs/2206.00364).
[[autodoc]] schedulers.scheduling_lms_discrete.LMSDiscreteScheduler
[[autodoc]] LMSDiscreteScheduler
#### Pseudo numerical methods for diffusion models (PNDM)
Original implementation can be found [here](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181).
[[autodoc]] schedulers.scheduling_pndm.PNDMScheduler
[[autodoc]] PNDMScheduler
#### variance exploding stochastic differential equation (SDE) scheduler
Original paper can be found [here](https://arxiv.org/abs/2011.13456).
[[autodoc]] schedulers.scheduling_sde_ve.ScoreSdeVeScheduler
[[autodoc]] ScoreSdeVeScheduler
#### variance preserving stochastic differential equation (SDE) scheduler

View File

@@ -10,23 +10,8 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Stable Diffusion
Under construction 🚧
# Quicktour
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
```
pip install diffusers
```
## Main classes
### Models
### Schedulers
### Pipeliens
For now please visit this [very in-detail blog post](https://huggingface.co/blog/stable_diffusion)

View File

@@ -42,8 +42,8 @@ available a colab notebook to directly try them out.
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochatic_karras_ve](./api/pipelines/stochatic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.

View File

@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
Install Diffusers for with PyTorch. Support for other libraries will come in the future
🤗 Diffusers is tested on Python 3.6+, and PyTorch 1.4.0+.
🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+.
## Install with pip
@@ -66,7 +66,7 @@ Clone the repository and install 🤗 Diffusers with the following commands:
```bash
git clone https://github.com/huggingface/diffusers.git
cd transformers
cd diffusers
pip install -e .
```

View File

@@ -11,22 +11,33 @@ specific language governing permissions and limitations under the License.
-->
# How to use the ONNX Runtime for inference
# Quicktour
🤗 Diffusers provides a Stable Diffusion pipeline compatible with the ONNX Runtime. This allows you to run Stable Diffusion on any hardware that supports ONNX (including CPUs), and where an accelerated version of PyTorch is not available.
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
## Installation
```
pip install diffusers
- TODO
## Stable Diffusion Inference
The snippet below demonstrates how to use the ONNX runtime. You need to use `StableDiffusionOnnxPipeline` instead of `StableDiffusionPipeline`. You also need to download the weights from the `onnx` branch of the repository, and indicate the runtime provider you want to use.
```python
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionOnnxPipeline
pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="onnx",
provider="CUDAExecutionProvider",
use_auth_token=True,
)
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
## Main classes
### Models
### Schedulers
### Pipeliens
## Known Issues
- Generating multiple prompts in a batch seems to take too much memory. While we look into it, you may need to iterate instead of batching.

View File

@@ -10,23 +10,6 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# OpenVINO
# Quicktour
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
```
pip install diffusers
```
## Main classes
### Models
### Schedulers
### Pipeliens
Under construction 🚧

View File

@@ -1,32 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Quicktour
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
```
pip install diffusers
```
## Main classes
### Models
### Schedulers
### Pipeliens

View File

@@ -86,11 +86,11 @@ just like we did before only that now you need to pass your `AUTH_TOKEN`:
>>> generator = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=AUTH_TOKEN)
```
If you do not pass your authentification token you will see that the diffusion system will not be correctly
downloaded. Forcing the user to pass an authentification token ensures that it can be verified that the
If you do not pass your authentication token you will see that the diffusion system will not be correctly
downloaded. Forcing the user to pass an authentication token ensures that it can be verified that the
user has indeed read and accepted the license, which also means that an internet connection is required.
**Note**: If you do not want to be forced to pass an authentification token, you can also simply download
**Note**: If you do not want to be forced to pass an authentication token, you can also simply download
the weights locally via:
```
@@ -98,7 +98,7 @@ git lfs install
git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
```
and then load locally saved weights into the pipeline. This way, you do not need to pass an authentification
and then load locally saved weights into the pipeline. This way, you do not need to pass an authentication
token. Assuming that `"./stable-diffusion-v1-4"` is the local path to the cloned stable-diffusion-v1-4 repo,
you can also load the pipeline as follows:
@@ -114,7 +114,7 @@ Running the pipeline is then identical to the code above as it's the same model
>>> image.save("image_of_squirrel_painting.png")
```
Diffusion systems can be used with multiple different [schedulers]("api/schedulers") each with their
Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their
pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler,
you could use it as follows:
@@ -131,15 +131,15 @@ you could use it as follows:
[Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model
and can do much more than just generating images from text. We have dedicated a whole documentation page,
just for Stable Diffusion [here]("./conceptual/stable_diffusion").
just for Stable Diffusion [here](./conceptual/stable_diffusion).
If you want to know how to optimize Stable Diffusion to run on less memory, higher inference speeds, on specific hardware, such as Mac, or with [ONNX Runtime](https://onnxruntime.ai/), please have a look at our
optimization pages:
- [Optimized PyTorch on GPU]("./optimization/fp16")
- [Mac OS with PyTorch]("./optimization/mps")
- [ONNX]("./optimization/onnx)
- [Other clever optimization tricks]("./optimization/other)
- [Optimized PyTorch on GPU](./optimization/fp16)
- [Mac OS with PyTorch](./optimization/mps)
- [ONNX](./optimization/onnx)
- [OpenVINO](./optimization/open_vino)
If you want to fine-tune or train your diffusion model, please have a look at the [**training section**](./training/overview)

View File

@@ -10,10 +10,60 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Training
# 🧨 Diffusers Training Examples
You can train diffusion models in multiple ways:
Diffusers examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library
for a variety of use cases.
**Note**: If you are looking for **official** examples on how to use `diffusers` for inference,
please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)
Our examples aspire to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
More specifically, this means:
- **Self-contained**: An example script shall only depend on "pip-install-able" Python packages that can be found in a `requirements.txt` file. Example scripts shall **not** depend on any local files. This means that one can simply download an example script, *e.g.* [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), install the required dependencies, *e.g.* [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt) and execute the example script.
- **Easy-to-tweak**: While we strive to present as many use cases as possible, the example scripts are just that - examples. It is expected that they won't work out-of-the box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data and the training loop to allow you to tweak and edit them as required.
- **Beginner-friendly**: We do not aim for providing state-of-the-art training scripts for the newest models, but rather examples that can be used as a way to better understand diffusion models and how to use them with the `diffusers` library. We often purposefully leave out certain state-of-the-art methods if we consider them too complex for beginners.
- **One-purpose-only**: Examples should show one task and one task only. Even if a task is from a modeling
point of view very similar, *e.g.* image super-resolution and image modification tend to use the same model and training method, we want examples to showcase only one task to keep them as readable and easy-to-understand as possible.
We provide **official** examples that cover the most popular tasks of diffusion models.
*Official* examples are **actively** maintained by the `diffusers` maintainers and we try to rigorously follow our example philosophy as defined above.
If you feel like another important example should exist, we are more than happy to welcome a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) or directly a [Pull Request](https://github.com/huggingface/diffusers/compare) from you!
Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support:
- [Unconditional Training](./unconditional_training)
- [Text-to-Image Training](./text2image)
- [Text Inversion](./text_inversion)
| Task | 🤗 Accelerate | 🤗 Datasets | Colab
|---|---|:---:|:---:|
| [**Unconditional Image Generation**](./unconditional_training) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [**Text-to-Image**](./text2image) | - | - |
| [**Text-Inversion**](./text_inversion) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
## Community
In addition, we provide **community** examples, which are examples added and maintained by our community.
Community examples can consist of both *training* examples or *inference* pipelines.
For such examples, we are more lenient regarding the philosophy defined above and also cannot guarantee to provide maintenance for every issue.
Examples that are useful for the community, but are either not yet deemed popular or not yet following our above philosophy should go into the [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) folder. The community folder therefore includes training examples and inference pipelines.
**Note**: Community examples can be a [great first contribution](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to show to the community how you like to use `diffusers` 🪄.
## Important note
To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd in the example folder of your choice and run
```bash
pip install -r requirements.txt
```

View File

@@ -11,22 +11,6 @@ specific language governing permissions and limitations under the License.
-->
# Text-to-Image Training
# Quicktour
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
```
pip install diffusers
```
## Main classes
### Models
### Schedulers
### Pipeliens
Under construction 🚧

View File

@@ -49,7 +49,7 @@ The `textual_inversion.py` script [here](https://github.com/huggingface/diffuser
### Installing the dependencies
Before running the scipts, make sure to install the library's training dependencies:
Before running the scripts, make sure to install the library's training dependencies:
```bash
pip install diffusers[training] accelerate transformers
@@ -68,7 +68,7 @@ You need to accept the model license before downloading or using the weights. In
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
Run the following command to autheticate your token
Run the following command to authenticate your token
```bash
huggingface-cli login

View File

@@ -10,23 +10,140 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Unconditional Image-Generation
In this section, we explain how one can train an unconditional image generation diffusion
model. "Unconditional" because the model is not conditioned on any context to generate an image - once trained the model will simply generate images that resemble its training data
distribution.
# Quicktour
## Installing the dependencies
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
Before running the scripts, make sure to install the library's training dependencies:
```
pip install diffusers
```bash
pip install diffusers[training] accelerate datasets
```
## Main classes
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
### Models
```bash
accelerate config
```
### Schedulers
## Unconditional Flowers
### Pipeliens
The command to train a DDPM UNet model on the Oxford Flowers dataset:
```bash
accelerate launch train_unconditional.py \
--dataset_name="huggan/flowers-102-categories" \
--resolution=64 \
--output_dir="ddpm-ema-flowers-64" \
--train_batch_size=16 \
--num_epochs=100 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-4 \
--lr_warmup_steps=500 \
--mixed_precision=no \
--push_to_hub
```
An example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64
A full training run takes 2 hours on 4xV100 GPUs.
<img src="https://user-images.githubusercontent.com/26864830/180248660-a0b143d0-b89a-42c5-8656-2ebf6ece7e52.png" width="700" />
## Unconditional Pokemon
The command to train a DDPM UNet model on the Pokemon dataset:
```bash
accelerate launch train_unconditional.py \
--dataset_name="huggan/pokemon" \
--resolution=64 \
--output_dir="ddpm-ema-pokemon-64" \
--train_batch_size=16 \
--num_epochs=100 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-4 \
--lr_warmup_steps=500 \
--mixed_precision=no \
--push_to_hub
```
An example trained model: https://huggingface.co/anton-l/ddpm-ema-pokemon-64
A full training run takes 2 hours on 4xV100 GPUs.
<img src="https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png" width="700" />
## Using your own data
To use your own dataset, there are 2 ways:
- you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
**Note**: If you want to create your own training dataset please have a look at [this document](https://huggingface.co/docs/datasets/image_process#image-datasets).
Below, we explain both in more detail.
### Provide the dataset as a folder
If you provide your own folders with images, the script expects the following directory structure:
```bash
data_dir/xxx.png
data_dir/xxy.png
data_dir/[...]/xxz.png
```
In other words, the script will take care of gathering all images inside the folder. You can then run the script like this:
```bash
accelerate launch train_unconditional.py \
--train_data_dir <path-to-train-directory> \
<other-arguments>
```
Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
### Upload your data to the hub, as a (possibly private) repo
It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:
```python
from datasets import load_dataset
# example 1: local folder
dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")
# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="path_to_zip_file")
# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset(
"imagefolder",
data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip",
)
# example 4: providing several splits
dataset = load_dataset(
"imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]}
)
```
`ImageFolder` will create an `image` column containing the PIL-encoded images.
Next, push it to the hub!
```python
# assuming you have ran the huggingface-cli login command in a terminal
dataset.push_to_hub("name_of_your_dataset")
# if you want to push to a private repo, simply pass private=True:
dataset.push_to_hub("name_of_your_dataset", private=True)
```
and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).

View File

@@ -10,23 +10,39 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Conditional Image Generation
The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference
# Quicktour
Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
You can use the [`DiffusionPipeline`] for any [Diffusers' checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads).
In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generation with [Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256):
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
```python
>>> from diffusers import DiffusionPipeline
>>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
```
pip install diffusers
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on GPU.
You can move the generator object to GPU, just like you would in PyTorch.
```python
>>> generator.to("cuda")
```
## Main classes
Now you can use the `generator` on your text prompt:
### Models
```python
>>> image = generator("An image of a squirrel in Picasso style").images[0]
```
### Schedulers
The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).
### Pipeliens
You can save the image by simply calling:
```python
>>> image.save("image_of_squirrel_painting.png")
```

View File

@@ -10,23 +10,6 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Custom Pipeline
# Quicktour
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
```
pip install diffusers
```
## Main classes
### Models
### Schedulers
### Pipeliens
Under construction 🚧

View File

@@ -10,23 +10,37 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Text-Guided Image-to-Image Generation
The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images.
# Quicktour
```python
from torch import autocast
import requests
from PIL import Image
from io import BytesIO
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
from diffusers import StableDiffusionImg2ImgPipeline
# load the pipeline
device = "cuda"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True
).to(device)
# let's download an initial image
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
pip install diffusers
```
## Main classes
### Models
### Schedulers
### Pipeliens
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)

View File

@@ -10,23 +10,41 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Text-Guided Image-Inpainting
The [`StableDiffusionInpaintPipeline`] lets you edit specific parts of an image by providing a mask and text prompt.
```python
from io import BytesIO
from torch import autocast
import requests
import PIL
from diffusers import StableDiffusionInpaintPipeline
# Quicktour
def download_image(url):
response = requests.get(url)
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
```
pip install diffusers
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
device = "cuda"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True
).to(device)
prompt = "a cat sitting on a bench"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```
## Main classes
### Models
### Schedulers
### Pipeliens
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)

View File

@@ -10,23 +10,6 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Loading
# Quicktour
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
```
pip install diffusers
```
## Main classes
### Models
### Schedulers
### Pipeliens
Under construction 🚧

View File

@@ -12,21 +12,41 @@ specific language governing permissions and limitations under the License.
# Quicktour
# Unonditional Image Generation
Start using Diffusers🧨 quickly!
To start, use the [`DiffusionPipeline`] for quick inference and sample generations!
The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference
Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
You can use the [`DiffusionPipeline`] for any [Diffusers' checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads).
In this guide though, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239):
```python
>>> from diffusers import DiffusionPipeline
>>> generator = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256")
```
pip install diffusers
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on GPU.
You can move the generator object to GPU, just like you would in PyTorch.
```python
>>> generator.to("cuda")
```
## Main classes
Now you can use the `generator` on your text prompt:
### Models
```python
>>> image = generator().images[0]
```
The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).
You can save the image by simply calling:
```python
>>> image.save("generated_image.png")
```
### Schedulers
### Pipeliens

View File

@@ -14,7 +14,7 @@ Colab for inference
## Running locally
### Installing the dependencies
Before running the scipts, make sure to install the library's training dependencies:
Before running the scripts, make sure to install the library's training dependencies:
```bash
pip install diffusers[training] accelerate transformers
@@ -33,7 +33,7 @@ You need to accept the model license before downloading or using the weights. In
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
Run the following command to autheticate your token
Run the following command to authenticate your token
```bash
huggingface-cli login

View File

@@ -238,7 +238,6 @@ class TextualInversionDataset(Dataset):
placeholder_token="*",
center_crop=False,
):
self.data_root = data_root
self.tokenizer = tokenizer
self.learnable_property = learnable_property
@@ -423,7 +422,7 @@ def main():
eps=args.adam_epsilon,
)
# TODO (patil-suraj): laod scheduler using args
# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
)
@@ -505,7 +504,9 @@ def main():
noise = torch.randn(latents.shape).to(latents.device)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)

View File

@@ -4,7 +4,7 @@ Creating a training image set is [described in a different document](https://hug
### Installing the dependencies
Before running the scipts, make sure to install the library's training dependencies:
Before running the scripts, make sure to install the library's training dependencies:
```bash
pip install diffusers[training] accelerate datasets
@@ -102,7 +102,7 @@ from datasets import load_dataset
# example 1: local folder
dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")
# example 2: local files (suppoted formats are tar, gzip, zip, xz, rar, zstd)
# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="path_to_zip_file")
# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)

View File

@@ -130,7 +130,7 @@ def main(args):
bsz = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
).long()
# Add noise to the clean images according to the noise magnitude at each timestep

View File

@@ -22,7 +22,7 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
new_item = old_item
new_item = new_item.replace("block.", "resnets.")
new_item = new_item.replace("conv_shorcut", "conv1")
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = new_item.replace("in_shortcut", "conv_shortcut")
new_item = new_item.replace("temb_proj", "time_emb_proj")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

View File

@@ -0,0 +1,690 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """
import argparse
import os
import torch
try:
from omegaconf import OmegaConf
except ImportError:
raise ImportError(
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
)
from diffusers import (
AutoencoderKL,
DDIMScheduler,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
to them. It splits attention layers, and takes into account additional replacements
that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def create_unet_diffusers_config(original_config):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
unet_params = original_config.model.params.unet_config.params
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
config = dict(
sample_size=unet_params.image_size,
in_channels=unet_params.in_channels,
out_channels=unet_params.out_channels,
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
layers_per_block=unet_params.num_res_blocks,
cross_attention_dim=unet_params.context_dim,
attention_head_dim=unet_params.num_heads,
)
return config
def create_vae_diffusers_config(original_config):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
vae_params = original_config.model.params.first_stage_config.params.ddconfig
_ = original_config.model.params.first_stage_config.params.embed_dim
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = dict(
sample_size=vae_params.resolution,
in_channels=vae_params.in_channels,
out_channels=vae_params.out_ch,
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
latent_channels=vae_params.z_channels,
layers_per_block=vae_params.num_res_blocks,
)
return config
def create_diffusers_schedular(original_config):
schedular = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps,
beta_start=original_config.model.params.linear_start,
beta_end=original_config.model.params.linear_end,
beta_schedule="scaled_linear",
)
return schedular
def create_ldm_bert_config(original_config):
bert_params = original_config.model.parms.cond_stage_config.params
config = LDMBertConfig(
d_model=bert_params.n_embed,
encoder_layers=bert_params.n_layer,
encoder_ffn_dim=bert_params.n_embed * 4,
)
return config
def convert_ldm_unet_checkpoint(checkpoint, config):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
# extract state_dict for UNet
unet_state_dict = {}
unet_key = "model.diffusion_model."
keys = list(checkpoint.keys())
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_paths = renew_resnet_paths(resnet_0)
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
resnet_1_paths = renew_resnet_paths(resnet_1)
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
resnet_0_paths = renew_resnet_paths(resnets)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if ["conv.weight", "conv.bias"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
new_checkpoint[new_path] = unet_state_dict[old_path]
return new_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
def convert_ldm_bert_checkpoint(checkpoint, config):
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
def _copy_linear(hf_linear, pt_linear):
hf_linear.weight = pt_linear.weight
hf_linear.bias = pt_linear.bias
def _copy_layer(hf_layer, pt_layer):
# copy layer norms
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
# copy attn
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
# copy MLP
pt_mlp = pt_layer[1][1]
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
def _copy_layers(hf_layers, pt_layers):
for i, hf_layer in enumerate(hf_layers):
if i != 0:
i += i
pt_layer = pt_layers[i : i + 2]
_copy_layer(hf_layer, pt_layer)
hf_model = LDMBertModel(config).eval()
# copy embeds
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
# copy layer norm
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
# copy hidden layers
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
return hf_model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
)
# !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
parser.add_argument(
"--original_config_file",
default=None,
type=str,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument(
"--scheduler_type",
default="pndm",
type=str,
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
if args.original_config_file is None:
os.system(
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
args.original_config_file = "./v1-inference.yaml"
original_config = OmegaConf.load(args.original_config_file)
checkpoint = torch.load(args.checkpoint_path)["state_dict"]
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
if args.scheduler_type == "pndm":
scheduler = PNDMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
skip_prk_steps=True,
)
elif args.scheduler_type == "lms":
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
elif args.scheduler_type == "ddim":
scheduler = DDIMScheduler(
beta_start=beta_start,
beta_end=beta_end,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
else:
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config)
converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
unet = UNet2DConditionModel(**unet_config)
unet.load_state_dict(converted_unet_checkpoint)
# Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
# Convert the text model.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenCLIPEmbedder":
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
else:
text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
pipe.save_pretrained(args.dump_path)

View File

@@ -0,0 +1,196 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
import torch
from torch.onnx import export
from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
def onnx_export(
model,
model_args: tuple,
output_path: Path,
ordered_input_names,
output_names,
dynamic_axes,
opset,
use_external_data_format=False,
):
output_path.parent.mkdir(parents=True, exist_ok=True)
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
# so we check the torch version for backwards compatibility
if is_torch_less_than_1_11:
export(
model,
model_args,
f=output_path.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
enable_onnx_checker=True,
opset_version=opset,
)
else:
export(
model,
model_args,
f=output_path.as_posix(),
input_names=ordered_input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
opset_version=opset,
)
@torch.no_grad()
def convert_models(model_path: str, output_path: str, opset: int):
pipeline = StableDiffusionPipeline.from_pretrained(model_path, use_auth_token=True)
output_path = Path(output_path)
# TEXT ENCODER
text_input = pipeline.tokenizer(
"A sample prompt",
padding="max_length",
max_length=pipeline.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
onnx_export(
pipeline.text_encoder,
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
model_args=(text_input.input_ids.to(torch.int32)),
output_path=output_path / "text_encoder" / "model.onnx",
ordered_input_names=["input_ids"],
output_names=["last_hidden_state", "pooler_output"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
},
opset=opset,
)
# UNET
onnx_export(
pipeline.unet,
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
output_path=output_path / "unet" / "model.onnx",
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
output_names=["out_sample"], # has to be different from "sample" for correct tracing
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"timestep": {0: "batch"},
"encoder_hidden_states": {0: "batch", 1: "sequence"},
},
opset=opset,
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
)
# VAE ENCODER
vae_encoder = pipeline.vae
# need to get the raw tensor output (sample) from the encoder
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
onnx_export(
vae_encoder,
model_args=(torch.randn(1, 3, 512, 512), False),
output_path=output_path / "vae_encoder" / "model.onnx",
ordered_input_names=["sample", "return_dict"],
output_names=["latent_sample"],
dynamic_axes={
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)
# VAE DECODER
vae_decoder = pipeline.vae
# forward only through the decoder part
vae_decoder.forward = vae_encoder.decode
onnx_export(
vae_decoder,
model_args=(torch.randn(1, 4, 64, 64), False),
output_path=output_path / "vae_decoder" / "model.onnx",
ordered_input_names=["latent_sample", "return_dict"],
output_names=["sample"],
dynamic_axes={
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)
# SAFETY CHECKER
safety_checker = pipeline.safety_checker
safety_checker.forward = safety_checker.forward_onnx
onnx_export(
pipeline.safety_checker,
model_args=(torch.randn(1, 3, 224, 224), torch.randn(1, 512, 512, 3)),
output_path=output_path / "safety_checker" / "model.onnx",
ordered_input_names=["clip_input", "images"],
output_names=["out_images", "has_nsfw_concepts"],
dynamic_axes={
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
"images": {0: "batch", 1: "channels", 2: "height", 3: "width"},
},
opset=opset,
)
onnx_pipeline = StableDiffusionOnnxPipeline(
vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"),
text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"),
tokenizer=pipeline.tokenizer,
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
feature_extractor=pipeline.feature_extractor,
)
onnx_pipeline.save_pretrained(output_path)
print("ONNX pipeline saved to", output_path)
_ = StableDiffusionOnnxPipeline.from_pretrained(output_path, provider="CPUExecutionProvider")
print("ONNX pipeline is loadable")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
)
parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")
parser.add_argument(
"--opset",
default=14,
type=str,
help="The version of the ONNX operator set to use.",
)
args = parser.parse_args()
convert_models(args.model_path, args.output_path, args.opset)

View File

@@ -124,4 +124,4 @@ for mod in models:
assert torch.allclose(
logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
)
print(f"{mod.modelId} has passed succesfully!!!")
print(f"{mod.modelId} has passed successfully!!!")

View File

@@ -68,6 +68,7 @@ To create the package for pypi.
"""
import re
import os
from distutils.core import Command
from setuptools import find_packages, setup
@@ -78,14 +79,17 @@ from setuptools import find_packages, setup
_deps = [
"Pillow",
"accelerate>=0.11.0",
"black==22.3",
"black==22.8",
"datasets",
"filelock",
"flake8>=3.8.3",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.8.1",
"huggingface-hub>=0.9.1",
"importlib_metadata",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6",
"modelcards==0.1.4",
"numpy",
"pytest",
@@ -167,11 +171,18 @@ extras = {}
extras = {}
extras["quality"] = ["black==22.3", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
extras["docs"] = ["hf-doc-builder"]
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
extras["test"] = ["datasets", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"]
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
extras["torch"] = deps_list("torch")
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:
extras["flax"] = deps_list("jax", "jaxlib", "flax")
extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
install_requires = [
deps["importlib_metadata"],
@@ -180,13 +191,12 @@ install_requires = [
deps["numpy"],
deps["regex"],
deps["requests"],
deps["torch"],
deps["Pillow"],
]
setup(
name="diffusers",
version="0.3.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
@@ -198,7 +208,7 @@ setup(
package_dir={"": "src"},
packages=find_packages("src"),
include_package_data=True,
python_requires=">=3.6.0",
python_requires=">=3.7.0",
install_requires=install_requires,
extras_require=extras,
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},

View File

@@ -1,41 +1,53 @@
from .utils import is_inflect_available, is_scipy_available, is_transformers_available, is_unidecode_available
__version__ = "0.3.0.dev0"
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
KarrasVeScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
from .utils import (
is_flax_available,
is_inflect_available,
is_onnx_available,
is_scipy_available,
is_torch_available,
is_transformers_available,
is_unidecode_available,
)
__version__ = "0.4.0.dev0"
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
from .utils import logging
if is_scipy_available():
if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
get_polynomial_decay_schedule_with_warmup,
get_scheduler,
)
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
KarrasVeScheduler,
PNDMScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
)
from .training_utils import EMAModel
else:
from .utils.dummy_pt_objects import * # noqa F403
if is_torch_available() and is_scipy_available():
from .schedulers import LMSDiscreteScheduler
else:
from .utils.dummy_scipy_objects import * # noqa F403
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
from .training_utils import EMAModel
if is_transformers_available():
if is_torch_available() and is_transformers_available():
from .pipelines import (
LDMTextToImagePipeline,
StableDiffusionImg2ImgPipeline,
@@ -43,4 +55,23 @@ if is_transformers_available():
StableDiffusionPipeline,
)
else:
from .utils.dummy_transformers_objects import * # noqa F403
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import StableDiffusionOnnxPipeline
else:
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
if is_flax_available():
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
FlaxScoreSdeVeScheduler,
)
else:
from .utils.dummy_flax_objects import * # noqa F403

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
import dataclasses
import functools
import inspect
import json
@@ -37,9 +38,16 @@ _re_configuration_file = re.compile(r"config\.(.*)\.json")
class ConfigMixin:
r"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
methods for loading/downloading/saving configurations.
Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
- [`~ConfigMixin.from_config`]
- [`~ConfigMixin.save_config`]
Class attributes:
- **config_name** (`str`) -- A filename under which the config should stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class).
"""
config_name = None
ignore_for_config = []
@@ -74,8 +82,6 @@ class ConfigMixin:
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -90,6 +96,63 @@ class ConfigMixin:
@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
r"""
Instantiate a Python class from a pre-defined JSON-file.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
`./my_model_directory/`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
<Tip>
Passing `use_auth_token=True`` is required when you want to use a private model.
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.
</Tip>
"""
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
@@ -112,6 +175,7 @@ class ConfigMixin:
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = {"file_type": "config"}
@@ -208,6 +272,11 @@ class ConfigMixin:
# remove general kwargs if present in dict
if "kwargs" in expected_keys:
expected_keys.remove("kwargs")
# remove flax interal keys
if hasattr(cls, "_flax_internal_args"):
for arg in cls._flax_internal_args:
expected_keys.remove(arg)
# remove keys to be ignored
if len(cls.ignore_for_config) > 0:
expected_keys = expected_keys - set(cls.ignore_for_config)
@@ -220,11 +289,20 @@ class ConfigMixin:
# use value from config dict
init_dict[key] = config_dict.pop(key)
unused_kwargs = config_dict.update(kwargs)
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
if len(config_dict) > 0:
logger.warning(
f"The config attributes {config_dict} were passed to {cls.__name__}, "
"but are not expected and will be ignored. Please verify your "
f"{cls.config_name} configuration file."
)
unused_kwargs = {**config_dict, **kwargs}
passed_keys = set(init_dict.keys())
if len(expected_keys - passed_keys) > 0:
logger.warning(
logger.info(
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
)
@@ -298,10 +376,10 @@ class FrozenDict(OrderedDict):
def register_to_config(init):
"""
Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically
sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be
registered in the config, use the `ignore_for_config` class variable
r"""
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
shouldn't be registered in the config, use the `ignore_for_config` class variable
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
"""
@@ -338,3 +416,44 @@ def register_to_config(init):
getattr(self, "register_to_config")(**new_kwargs)
return inner_init
def flax_register_to_config(cls):
original_init = cls.__init__
@functools.wraps(original_init)
def init(self, *args, **kwargs):
if not isinstance(self, ConfigMixin):
raise RuntimeError(
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
"not inherit from `ConfigMixin`."
)
# Ignore private kwargs in the init. Retrieve all passed attributes
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
# Retrieve default values
fields = dataclasses.fields(self)
default_kwargs = {}
for field in fields:
# ignore flax specific attributes
if field.name in self._flax_internal_args:
continue
if type(field.default) == dataclasses._MISSING_TYPE:
default_kwargs[field.name] = None
else:
default_kwargs[field.name] = getattr(self, field.name)
# Make sure init_kwargs override default kwargs
new_kwargs = {**default_kwargs, **init_kwargs}
# Get positional arguments aligned with kwargs
for i, arg in enumerate(args):
name = fields[i].name
new_kwargs[name] = arg
getattr(self, "register_to_config")(**new_kwargs)
original_init(self, *args, **kwargs)
cls.__init__ = init
return cls

View File

@@ -4,14 +4,17 @@
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0",
"black": "black==22.3",
"black": "black==22.8",
"datasets": "datasets",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.8.1",
"huggingface-hub": "huggingface-hub>=0.9.1",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards==0.1.4",
"numpy": "numpy",
"pytest": "pytest",

View File

@@ -0,0 +1,526 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pickle import UnpicklingError
from typing import Any, Dict, Union
import jax
import jax.numpy as jnp
import msgpack.exceptions
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from .modeling_utils import WEIGHTS_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
logger = logging.get_logger(__name__)
class FlaxModelMixin:
r"""
Base class for all flax models.
[`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading,
downloading and saving models.
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_flax_internal_args = ["name", "parent"]
@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""
return cls(config, **kwargs)
def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
"""
Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
"""
# taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
def conditional_cast(param):
if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
param = param.astype(dtype)
return param
if mask is None:
return jax.tree_map(conditional_cast, params)
flat_params = flatten_dict(params)
flat_mask, _ = jax.tree_flatten(mask)
for masked, key in zip(flat_mask, flat_params.keys()):
if masked:
param = flat_params[key]
flat_params[key] = conditional_cast(param)
return unflatten_dict(flat_params)
def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
the `params` in place.
This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
you want to cast, and should be `False` for those you want to skip.
Examples:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
>>> params = model.to_bf16(params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
... for path in flat_params
... }
>>> mask = traverse_util.unflatten_dict(mask)
>>> params = model.to_bf16(params, mask)
```"""
return self._cast_floating_to(params, jnp.bfloat16, mask)
def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
you want to cast, and should be `False` for those you want to skip
Examples:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # Download model and configuration from huggingface.co
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> params = model.to_f16(params)
>>> # now cast back to fp32
>>> params = model.to_fp32(params)
```"""
return self._cast_floating_to(params, jnp.float32, mask)
def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
r"""
Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
`params` in place.
This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
Arguments:
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
mask (`Union[Dict, FrozenDict]`):
A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
you want to cast, and should be `False` for those you want to skip
Examples:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> # By default, the model params will be in fp32, to cast these to float16
>>> params = model.to_fp16(params)
>>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
... for path in flat_params
... }
>>> mask = traverse_util.unflatten_dict(mask)
>>> params = model.to_fp16(params, mask)
```"""
return self._cast_floating_to(params, jnp.float16, mask)
def init_weights(self, rng: jax.random.PRNGKey) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
dtype: jnp.dtype = jnp.float32,
*model_args,
**kwargs,
):
r"""
Instantiate a pretrained flax model from a pre-trained model configuration.
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids are namespaced under a user or organization name, like
`CompVis/stable-diffusion-v1-4`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
e.g., `./my_model_directory/`.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and
[`~ModelMixin.to_bf16`].
model_args (sequence of positional arguments, *optional*):
All remaining positional arguments will be passed to the underlying model's `__init__` method.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
automatically loaded:
- If a configuration is provided with `config`, `**kwargs` will be directly passed to the
underlying model's `__init__` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to
a configuration attribute will be used to override said attribute with the supplied `kwargs`
value. Remaining keys that do not correspond to any configuration attribute will be passed to the
underlying model's `__init__` function.
Examples:
```python
>>> from diffusers import FlaxUNet2DConditionModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
```"""
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_auto_class = kwargs.pop("_from_auto", False)
subfolder = kwargs.pop("subfolder", None)
user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
# Load config if we don't provide a configuration
config_path = config if config is not None else pretrained_model_name_or_path
model, model_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
# model args
dtype=dtype,
**kwargs,
)
# Load model
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights."
)
else:
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}."
)
else:
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=FLAX_WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login` and pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
" internet connection or see how to run the library in offline mode at"
" 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
try:
with open(model_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
try:
with open(model_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.ndarray
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
# flatten dicts
state = flatten_dict(state)
params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0))
required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
shape_state = flatten_dict(unfreeze(params_shape_tree))
missing_keys = required_params - set(state.keys())
unexpected_keys = set(state.keys()) - required_params
if missing_keys:
logger.warning(
f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
"Make sure to call model.init_weights to initialize the missing weights."
)
cls._missing_keys = missing_keys
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys = []
for key in state.keys():
if key in shape_state and state[key].shape != shape_state[key].shape:
raise ValueError(
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. "
)
# remove unexpected keys to not be saved again
for unexpected_key in unexpected_keys:
del state[unexpected_key]
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
" training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
)
# dictionary of key: dtypes for the model params
param_dtypes = jax.tree_map(lambda x: x.dtype, state)
# extract keys of parameters not in jnp.float32
fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
# raise a warning if any of the parameters are not in jnp.float32
if len(fp16_params) > 0:
logger.warning(
f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
"You should probably UPCAST the model weights to float32 if this was not intended. "
"See [`~ModelMixin.to_fp32`] for further information on how to do this."
)
if len(bf16_params) > 0:
logger.warning(
f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
"You should probably UPCAST the model weights to float32 if this was not intended. "
"See [`~ModelMixin.to_fp32`] for further information on how to do this."
)
return model, unflatten_dict(state)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
params: Union[Dict, FrozenDict],
is_main_process: bool = True,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`[`~FlaxModelMixin.from_pretrained`]` class method
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
params (`Union[Dict, FrozenDict]`):
A `PyTree` of model parameters.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
os.makedirs(save_directory, exist_ok=True)
model_to_save = self
# Attach architecture to the config
# Save the config
if is_main_process:
model_to_save.save_config(save_directory)
# save model
output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
with open(output_model_file, "wb") as f:
model_bytes = to_bytes(params)
f.write(model_bytes)
logger.info(f"Model weights saved in {output_model_file}")

View File

@@ -117,27 +117,10 @@ class ModelMixin(torch.nn.Module):
Base class for all models.
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
and saving models as well as a few methods common to all models to:
and saving models.
- resize the input embeddings,
- prune heads in the self-attention heads.
Class attributes (overridden by derived classes):
- **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class for this
model architecture.
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
taking as arguments:
- **model** ([`ModelMixin`]) -- An instance of the model on which to load the TensorFlow checkpoint.
- **config** ([`PreTrainedConfigMixin`]) -- An instance of the configuration associated to the model.
- **path** (`str`) -- A path to the TensorFlow checkpoint.
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
classes of the same architecture adding modules on top of the base model.
- **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
[`~modeling_utils.ModelMixin.save_pretrained`].
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
@@ -150,11 +133,10 @@ class ModelMixin(torch.nn.Module):
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = torch.save,
**kwargs,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
`[`~ModelMixin.from_pretrained`]` class method.
`[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
Arguments:
save_directory (`str` or `os.PathLike`):
@@ -166,9 +148,6 @@ class ModelMixin(torch.nn.Module):
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method.
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -219,39 +198,16 @@ class ModelMixin(torch.nn.Module):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
e.g., `./my_model_directory/`.
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
`./my_model_directory/`.
config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
Can be either:
- an instance of a class derived from [`ConfigMixin`],
- a string or path valid as input to [`~ConfigMixin.from_pretrained`].
ConfigMixinuration for the model to use instead of an automatically loaded configuration.
ConfigMixinuration can be automatically loaded when:
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
model).
- The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the save
directory.
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
configuration JSON file named *config.json* is found in the directory.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
from_tf (`bool`, *optional*, defaults to `False`):
Load the model weights from a TensorFlow checkpoint save file (see docstring of
`pretrained_model_name_or_path` argument).
from_flax (`bool`, *optional*, defaults to `False`):
Load the model weights from a Flax checkpoint save file (see docstring of
`pretrained_model_name_or_path` argument).
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -262,35 +218,25 @@ class ModelMixin(torch.nn.Module):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
when running `diffusers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
automatically loaded:
- If a configuration is provided with `config`, `**kwargs` will be directly passed to the
underlying model's `__init__` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that corresponds
to a configuration attribute will be used to override said attribute with the supplied `kwargs`
value. Remaining keys that do not correspond to any configuration attribute will be passed to the
underlying model's `__init__` function.
<Tip>
Passing `use_auth_token=True`` is required when you want to use a private model.
@@ -299,8 +245,8 @@ class ModelMixin(torch.nn.Module):
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
use this method in a firewalled environment.
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.
</Tip>
@@ -404,7 +350,7 @@ class ModelMixin(torch.nn.Module):
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {WEIGHTS_NAME} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(

View File

@@ -1,4 +1,5 @@
import math
from typing import Optional
import torch
import torch.nn.functional as F
@@ -10,16 +11,24 @@ class AttentionBlock(nn.Module):
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (:obj:`int`): The number of channels in the input and output.
num_head_channels (:obj:`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""
def __init__(
self,
channels,
num_head_channels=None,
num_groups=32,
rescale_output_factor=1.0,
eps=1e-5,
channels: int,
num_head_channels: Optional[int] = None,
num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels
@@ -86,16 +95,33 @@ class AttentionBlock(nn.Module):
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image
standard transformer action. Finally, reshape to image.
Parameters:
in_channels (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
"""
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
def __init__(
self,
in_channels: int,
n_heads: int,
d_head: int,
depth: int = 1,
dropout: float = 0.0,
num_groups: int = 32,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
@@ -112,22 +138,44 @@ class SpatialTransformer(nn.Module):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)
def forward(self, x, context=None):
def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
for block in self.transformer_blocks:
x = block(x, context=context)
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
x = self.proj_out(x)
return x + x_in
hidden_states = block(hidden_states, context=context)
hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
r"""
A basic Transformer block.
Parameters:
dim (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
"""
def __init__(
self,
dim: int,
n_heads: int,
d_head: int,
dropout=0.0,
context_dim: Optional[int] = None,
gated_ff: bool = True,
checkpoint: bool = True,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
@@ -145,15 +193,30 @@ class BasicTransformerBlock(nn.Module):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
def forward(self, hidden_states, context=None):
hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
r"""
A cross attention layer.
Parameters:
query_dim (:obj:`int`): The number of channels in the query.
context_dim (:obj:`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
def __init__(
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
):
super().__init__()
inner_dim = dim_head * heads
context_dim = context_dim if context_dim is not None else query_dim
@@ -185,26 +248,39 @@ class CrossAttention(nn.Module):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def forward(self, x, context=None, mask=None):
batch_size, sequence_length, dim = x.shape
def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, dim = hidden_states.shape
q = self.to_q(x)
context = context if context is not None else x
k = self.to_k(context)
v = self.to_v(context)
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
key = self.to_k(context)
value = self.to_v(context)
q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v)
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
# attention, what we cannot get enough of
hidden_states = self._attention(q, k, v, sequence_length, dim)
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
return self.to_out(hidden_states)
def _attention(self, query, key, value, sequence_length, dim):
def _attention(self, query, key, value):
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _sliced_attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
@@ -213,11 +289,9 @@ class CrossAttention(nn.Module):
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
)
attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
@@ -227,7 +301,20 @@ class CrossAttention(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
r"""
A feed-forward layer.
Parameters:
dim (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
def __init__(
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
@@ -235,16 +322,24 @@ class FeedForward(nn.Module):
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
def forward(self, hidden_states):
return self.net(hidden_states)
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`): The number of channels in the output.
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)

View File

@@ -0,0 +1,180 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flax.linen as nn
import jax.numpy as jnp
class FlaxAttentionBlock(nn.Module):
query_dim: int
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5
# Weights were exported with old names {to_q, to_k, to_v, to_out}
self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def __call__(self, hidden_states, context=None, deterministic=True):
context = hidden_states if context is None else context
query_proj = self.query(hidden_states)
key_proj = self.key(context)
value_proj = self.value(context)
query_states = self.reshape_heads_to_batch_dim(query_proj)
key_states = self.reshape_heads_to_batch_dim(key_proj)
value_states = self.reshape_heads_to_batch_dim(value_proj)
# compute attentions
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=2)
# attend to values
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
hidden_states = self.proj_attn(hidden_states)
return hidden_states
class FlaxBasicTransformerBlock(nn.Module):
dim: int
n_heads: int
d_head: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
# self attention
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
# cross attention
residual = hidden_states
hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
hidden_states = hidden_states + residual
# feed forward
residual = hidden_states
hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
return hidden_states
class FlaxSpatialTransformer(nn.Module):
in_channels: int
n_heads: int
d_head: int
depth: int = 1
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.transformer_blocks = [
FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
for _ in range(self.depth)
]
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
# import ipdb; ipdb.set_trace()
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)
for transformer_block in self.transformer_blocks:
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
hidden_states = hidden_states.reshape(batch, height, width, channels)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class FlaxGluFeedForward(nn.Module):
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
inner_dim = self.dim * 4
self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dense2 = nn.Dense(self.dim, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.dense1(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
hidden_states = hidden_linear * nn.gelu(hidden_gelu)
hidden_states = self.dense2(hidden_states)
return hidden_states

View File

@@ -19,7 +19,12 @@ from torch import nn
def get_timestep_embedding(
timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -55,7 +60,7 @@ def get_timestep_embedding(
class TimestepEmbedding(nn.Module):
def __init__(self, channel, time_embed_dim, act_fn="silu"):
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
@@ -75,7 +80,7 @@ class TimestepEmbedding(nn.Module):
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
@@ -94,7 +99,7 @@ class Timesteps(nn.Module):
class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""
def __init__(self, embedding_size=256, scale=1.0):
def __init__(self, embedding_size: int = 256, scale: float = 1.0):
super().__init__()
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

View File

@@ -0,0 +1,56 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import flax.linen as nn
import jax.numpy as jnp
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
def get_sinusoidal_embeddings(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] tensor of positional embeddings.
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = jnp.exp(jnp.arange(half_dim) * -emb)
emb = timesteps[:, None] * emb[None, :]
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
return emb
class FlaxTimestepEmbedding(nn.Module):
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
@nn.compact
def __call__(self, temb):
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
temb = nn.silu(temb)
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
return temb
class FlaxTimesteps(nn.Module):
dim: int = 32
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(timesteps, self.dim)

View File

@@ -1,6 +1,5 @@
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -134,10 +133,10 @@ class FirUpsample2D(nn.Module):
kernel = [1] * factor
# setup kernel
kernel = np.asarray(kernel, dtype=np.float32)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = np.outer(kernel, kernel)
kernel /= np.sum(kernel)
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * (gain * (factor**2))
@@ -219,10 +218,10 @@ class FirDownsample2D(nn.Module):
kernel = [1] * factor
# setup kernel
kernel = np.asarray(kernel, dtype=np.float32)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = np.outer(kernel, kernel)
kernel /= np.sum(kernel)
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * gain
@@ -265,7 +264,7 @@ class ResnetBlock2D(nn.Module):
time_embedding_norm="default",
kernel=None,
output_scale_factor=1.0,
use_nin_shortcut=None,
use_in_shortcut=None,
up=False,
down=False,
):
@@ -322,10 +321,10 @@ class ResnetBlock2D(nn.Module):
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_nin_shortcut:
if self.use_in_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
@@ -391,16 +390,14 @@ def upsample_2d(x, kernel=None, factor=2, gain=1):
if kernel is None:
kernel = [1] * factor
kernel = np.asarray(kernel, dtype=np.float32)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = np.outer(kernel, kernel)
kernel /= np.sum(kernel)
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * (gain * (factor**2))
p = kernel.shape[0] - factor
return upfirdn2d_native(
x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
)
return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, kernel=None, factor=2, gain=1):
@@ -425,14 +422,14 @@ def downsample_2d(x, kernel=None, factor=2, gain=1):
if kernel is None:
kernel = [1] * factor
kernel = np.asarray(kernel, dtype=np.float32)
kernel = torch.tensor(kernel, dtype=torch.float32)
if kernel.ndim == 1:
kernel = np.outer(kernel, kernel)
kernel /= np.sum(kernel)
kernel = torch.outer(kernel, kernel)
kernel /= torch.sum(kernel)
kernel = kernel * gain
p = kernel.shape[0] - factor
return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
@@ -448,10 +445,15 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
if input.device.type == "mps":
out = out.to("cpu")
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out.to(input.device) # Move back to mps if necessary
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),

View File

@@ -0,0 +1,111 @@
import flax.linen as nn
import jax
import jax.numpy as jnp
class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
def __call__(self, hidden_states):
batch, height, width, channels = hidden_states.shape
hidden_states = jax.image.resize(
hidden_states,
shape=(batch, height * 2, width * 2, channels),
method="nearest",
)
hidden_states = self.conv(hidden_states)
return hidden_states
class FlaxDownsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(2, 2),
padding=((1, 1), (1, 1)), # padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states):
# pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
# hidden_states = jnp.pad(hidden_states, pad_width=pad)
hidden_states = self.conv(hidden_states)
return hidden_states
class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout_prob: float = 0.0
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32
def setup(self):
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.conv1 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.dropout = nn.Dropout(self.dropout_prob)
self.conv2 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
self.conv_shortcut = None
if use_nin_shortcut:
self.conv_shortcut = nn.Conv(
out_channels,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states, temb, deterministic=True):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.conv1(hidden_states)
temb = self.time_emb_proj(nn.swish(temb))
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return hidden_states + residual

View File

@@ -23,6 +23,38 @@ class UNet2DOutput(BaseOutput):
class UNet2DModel(ModelMixin, ConfigMixin):
r"""
UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
Input sample size.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to :
obj:`False`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(224, 448, 672, 896)`): Tuple of block output channels.
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
"""
@register_to_config
def __init__(
self,
@@ -82,6 +114,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
@@ -119,6 +152,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)
@@ -136,6 +170,17 @@ class UNet2DModel(ModelMixin, ConfigMixin):
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
"""r
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0

View File

@@ -23,6 +23,37 @@ class UNet2DConditionOutput(BaseOutput):
class UNet2DConditionModel(ModelMixin, ConfigMixin):
r"""
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
and returns sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int`, *optional*): The size of the input sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
The tuple of upsample blocks to use.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
"""
@register_to_config
def __init__(
self,
@@ -83,6 +114,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
@@ -122,6 +154,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
)
@@ -162,6 +195,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
encoder_hidden_states: torch.Tensor,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
"""r
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0

View File

@@ -0,0 +1,258 @@
from typing import Tuple, Union
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..modeling_flax_utils import FlaxModelMixin
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .unet_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxCrossAttnUpBlock2D,
FlaxDownBlock2D,
FlaxUNetMidBlock2DCrossAttn,
FlaxUpBlock2D,
)
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
"""
Args:
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: jnp.ndarray
@flax_register_to_config
class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
r"""
FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a
timestep and returns sample shaped output.
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int`, *optional*): The size of the input sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
"FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D",
"FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks.
"""
sample_size: int = 32
in_channels: int = 4
out_channels: int = 4
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: int = 8
cross_attention_dim: int = 1280
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
# init input tensors
sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
def setup(self):
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = nn.Conv(
block_out_channels[0],
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
# time
self.time_proj = FlaxTimesteps(block_out_channels[0])
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
# down
down_blocks = []
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(self.down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if down_block_type == "CrossAttnDownBlock2D":
down_block = FlaxCrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
attn_num_head_channels=self.attention_head_dim,
add_downsample=not is_final_block,
dtype=self.dtype,
)
else:
down_block = FlaxDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
add_downsample=not is_final_block,
dtype=self.dtype,
)
down_blocks.append(down_block)
self.down_blocks = down_blocks
# mid
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
attn_num_head_channels=self.attention_head_dim,
dtype=self.dtype,
)
# up
up_blocks = []
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
if up_block_type == "CrossAttnUpBlock2D":
up_block = FlaxCrossAttnUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
attn_num_head_channels=self.attention_head_dim,
add_upsample=not is_final_block,
dropout=self.dropout,
dtype=self.dtype,
)
else:
up_block = FlaxUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
add_upsample=not is_final_block,
dropout=self.dropout,
dtype=self.dtype,
)
up_blocks.append(up_block)
prev_output_channel = output_channel
self.up_blocks = up_blocks
# out
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.conv_out = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
def __call__(
self,
sample,
timesteps,
encoder_hidden_states,
return_dict: bool = True,
train: bool = False,
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
"""r
Args:
sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor
timestep (`jnp.ndarray` or `float` or `int`): timesteps
encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
# 1. time
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
if isinstance(down_block, FlaxCrossAttnDownBlock2D):
sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
else:
sample, res_samples = down_block(sample, t_emb, deterministic=not train)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
# 5. up
for up_block in self.up_blocks:
res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
if isinstance(up_block, FlaxCrossAttnUpBlock2D):
sample = up_block(
sample,
temb=t_emb,
encoder_hidden_states=encoder_hidden_states,
res_hidden_states_tuple=res_samples,
deterministic=not train,
)
else:
sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = nn.silu(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return FlaxUNet2DConditionOutput(sample=sample)

View File

@@ -31,6 +31,7 @@ def get_down_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
):
@@ -44,6 +45,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)
elif down_block_type == "AttnDownBlock2D":
@@ -55,6 +57,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
@@ -69,6 +72,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
@@ -104,6 +108,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)
@@ -119,6 +124,7 @@ def get_up_block(
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
@@ -132,6 +138,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
)
elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
@@ -145,6 +152,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
)
@@ -158,6 +166,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
)
elif up_block_type == "SkipUpBlock2D":
@@ -191,6 +200,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
)
raise ValueError(f"{up_block_type} does not exist.")
@@ -323,6 +333,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
in_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
resnets.append(
@@ -414,6 +425,7 @@ class AttnDownBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
)
)
@@ -498,6 +510,7 @@ class CrossAttnDownBlock2D(nn.Module):
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -807,7 +820,7 @@ class AttnSkipDownBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
use_in_shortcut=True,
down=True,
kernel="fir",
)
@@ -887,7 +900,7 @@ class SkipDownBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
use_in_shortcut=True,
down=True,
kernel="fir",
)
@@ -966,6 +979,7 @@ class AttnUpBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
)
)
@@ -979,7 +993,6 @@ class AttnUpBlock2D(nn.Module):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
@@ -1048,6 +1061,7 @@ class CrossAttnUpBlock2D(nn.Module):
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -1075,7 +1089,6 @@ class CrossAttnUpBlock2D(nn.Module):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
@@ -1139,7 +1152,6 @@ class UpBlock2D(nn.Module):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
@@ -1343,7 +1355,7 @@ class AttnSkipUpBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
use_in_shortcut=True,
up=True,
kernel="fir",
)
@@ -1440,7 +1452,7 @@ class SkipUpBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_nin_shortcut=True,
use_in_shortcut=True,
up=True,
kernel="fir",
)

View File

@@ -0,0 +1,263 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import flax.linen as nn
import jax.numpy as jnp
from .attention_flax import FlaxSpatialTransformer
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
class FlaxCrossAttnDownBlock2D(nn.Module):
in_channels: int
out_channels: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self):
resnets = []
attentions = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
attn_block = FlaxSpatialTransformer(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
dtype=self.dtype,
)
attentions.append(attn_block)
self.resnets = resnets
self.attentions = attentions
if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
output_states += (hidden_states,)
if self.add_downsample:
hidden_states = self.downsample(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class FlaxDownBlock2D(nn.Module):
in_channels: int
out_channels: int
dropout: float = 0.0
num_layers: int = 1
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self):
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
self.resnets = resnets
if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, temb, deterministic=True):
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
output_states += (hidden_states,)
if self.add_downsample:
hidden_states = self.downsample(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class FlaxCrossAttnUpBlock2D(nn.Module):
in_channels: int
out_channels: int
prev_output_channel: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self):
resnets = []
attentions = []
for i in range(self.num_layers):
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
res_block = FlaxResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
attn_block = FlaxSpatialTransformer(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
dtype=self.dtype,
)
attentions.append(attn_block)
self.resnets = resnets
self.attentions = attentions
if self.add_upsample:
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsample(hidden_states)
return hidden_states
class FlaxUpBlock2D(nn.Module):
in_channels: int
out_channels: int
prev_output_channel: int
dropout: float = 0.0
num_layers: int = 1
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self):
resnets = []
for i in range(self.num_layers):
res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels
res_block = FlaxResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=self.out_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
self.resnets = resnets
if self.add_upsample:
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsample(hidden_states)
return hidden_states
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
in_channels: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
]
attentions = []
for _ in range(self.num_layers):
attn_block = FlaxSpatialTransformer(
in_channels=self.in_channels,
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,
depth=1,
dtype=self.dtype,
)
attentions.append(attn_block)
res_block = FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout_prob=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
self.resnets = resnets
self.attentions = attentions
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
return hidden_states

View File

@@ -59,6 +59,7 @@ class Encoder(nn.Module):
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
):
@@ -86,6 +87,7 @@ class Encoder(nn.Module):
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
temb_channels=None,
)
@@ -99,13 +101,12 @@ class Encoder(nn.Module):
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=32,
resnet_groups=norm_num_groups,
temb_channels=None,
)
# out
num_groups_out = 32
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
@@ -138,6 +139,7 @@ class Decoder(nn.Module):
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
):
super().__init__()
@@ -156,7 +158,7 @@ class Decoder(nn.Module):
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=32,
resnet_groups=norm_num_groups,
temb_channels=None,
)
@@ -178,6 +180,7 @@ class Decoder(nn.Module):
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=None,
temb_channels=None,
)
@@ -185,8 +188,7 @@ class Decoder(nn.Module):
prev_output_channel = output_channel
# out
num_groups_out = 32
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
@@ -338,7 +340,10 @@ class DiagonalGaussianDistribution(object):
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
device = self.parameters.device
sample_device = "cpu" if device.type == "mps" else device
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
x = self.mean + self.std * sample
return x
def kl(self, other=None):
@@ -368,6 +373,27 @@ class DiagonalGaussianDistribution(object):
class VQModel(ModelMixin, ConfigMixin):
r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
Kavukcuoglu.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(64,)`): Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
"""
@register_to_config
def __init__(
self,
@@ -381,6 +407,7 @@ class VQModel(ModelMixin, ConfigMixin):
latent_channels: int = 3,
sample_size: int = 32,
num_vq_embeddings: int = 256,
norm_num_groups: int = 32,
):
super().__init__()
@@ -392,6 +419,7 @@ class VQModel(ModelMixin, ConfigMixin):
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=False,
)
@@ -409,6 +437,7 @@ class VQModel(ModelMixin, ConfigMixin):
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
)
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
@@ -437,6 +466,12 @@ class VQModel(ModelMixin, ConfigMixin):
return DecoderOutput(sample=dec)
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
h = self.encode(x).latents
dec = self.decode(h).sample
@@ -448,6 +483,26 @@ class VQModel(ModelMixin, ConfigMixin):
class AutoencoderKL(ModelMixin, ConfigMixin):
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
and Max Welling.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(64,)`): Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
"""
@register_to_config
def __init__(
self,
@@ -459,6 +514,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
):
super().__init__()
@@ -471,6 +527,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)
@@ -481,6 +538,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)
@@ -507,12 +565,24 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return DecoderOutput(sample=dec)
def forward(
self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample()
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample

187
src/diffusers/onnx_utils.py Normal file
View File

@@ -0,0 +1,187 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from pathlib import Path
from typing import Optional, Union
import numpy as np
from huggingface_hub import hf_hub_download
from .utils import is_onnx_available, logging
if is_onnx_available():
import onnxruntime as ort
ONNX_WEIGHTS_NAME = "model.onnx"
logger = logging.get_logger(__name__)
class OnnxRuntimeModel:
def __init__(self, model=None, **kwargs):
logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
self.model = model
self.model_save_dir = kwargs.get("model_save_dir", None)
self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
def __call__(self, **kwargs):
inputs = {k: np.array(v) for k, v in kwargs.items()}
return self.model.run(None, inputs)
@staticmethod
def load_model(path: Union[str, Path], provider=None):
"""
Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
Arguments:
path (`str` or `Path`):
Directory from which to load
provider(`str`, *optional*):
Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
"""
if provider is None:
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider"
return ort.InferenceSession(path, providers=[provider])
def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
[`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
latest_model_name.
Arguments:
save_directory (`str` or `Path`):
Directory where to save the model file.
file_name(`str`, *optional*):
Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
model with a different name.
"""
model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
src_path = self.model_save_dir.joinpath(self.latest_model_name)
dst_path = Path(save_directory).joinpath(model_file_name)
if not src_path.samefile(dst_path):
shutil.copyfile(src_path, dst_path)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
**kwargs,
):
"""
Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
method.:
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
os.makedirs(save_directory, exist_ok=True)
# saving model weights/files
self._save_pretrained(save_directory, **kwargs)
@classmethod
def _from_pretrained(
cls,
model_id: Union[str, Path],
use_auth_token: Optional[Union[bool, str, None]] = None,
revision: Optional[Union[str, None]] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
file_name: Optional[str] = None,
provider: Optional[str] = None,
**kwargs,
):
"""
Load a model from a directory or the HF Hub.
Arguments:
model_id (`str` or `Path`):
Directory from which to load
use_auth_token (`str` or `bool`):
Is needed to load models from a private or gated repository
revision (`str`):
Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
cache_dir (`Union[str, Path]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
file_name(`str`):
Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
different model files from the same repository or directory.
provider(`str`):
The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
kwargs (`Dict`, *optional*):
kwargs will be passed to the model during initialization
"""
model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
# load model from local directory
if os.path.isdir(model_id):
model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
kwargs["model_save_dir"] = Path(model_id)
# load model from hub
else:
# download model
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=model_file_name,
use_auth_token=use_auth_token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
)
kwargs["model_save_dir"] = Path(model_cache_path).parent
kwargs["latest_model_name"] = Path(model_cache_path).name
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
return cls(model=model, **kwargs)
@classmethod
def from_pretrained(
cls,
model_id: Union[str, Path],
force_download: bool = True,
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
**model_kwargs,
):
revision = None
if len(str(model_id).split("@")) == 2:
model_id, revision = model_id.split("@")
return cls._from_pretrained(
model_id=model_id,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
use_auth_token=use_auth_token,
**model_kwargs,
)

View File

@@ -23,13 +23,17 @@ from typing import List, Optional, Union
import numpy as np
import torch
import diffusers
import PIL
from huggingface_hub import snapshot_download
from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, BaseOutput, logging
from .modeling_utils import WEIGHTS_NAME
from .onnx_utils import ONNX_WEIGHTS_NAME
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging
INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -43,6 +47,7 @@ LOADABLE_CLASSES = {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_config", "from_config"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
@@ -72,7 +77,20 @@ class ImagePipelineOutput(BaseOutput):
class DiffusionPipeline(ConfigMixin):
r"""
Base class for all models.
[`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
- move all PyTorch modules to the device of your choice
- enabling/disabling the progress bar for the denoising iteration
Class attributes:
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
components of the diffusion pipeline.
"""
config_name = "model_index.json"
def register_modules(self, **kwargs):
@@ -80,7 +98,7 @@ class DiffusionPipeline(ConfigMixin):
from diffusers import pipelines
for name, module in kwargs.items():
# retrive library
# retrieve library
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
@@ -94,7 +112,7 @@ class DiffusionPipeline(ConfigMixin):
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# retrive class_name
# retrieve class_name
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}
@@ -106,6 +124,15 @@ class DiffusionPipeline(ConfigMixin):
setattr(self, name, module)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
"""
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
"""
self.save_config(save_directory)
model_index_dict = dict(self.config)
@@ -146,6 +173,10 @@ class DiffusionPipeline(ConfigMixin):
@property
def device(self) -> torch.device:
r"""
Returns:
`torch.device`: The torch device on which the pipeline is located.
"""
module_names, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
@@ -156,7 +187,94 @@ class DiffusionPipeline(ConfigMixin):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Add docstrings
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
`CompVis/ldm-text2im-large-256`.
- A path to a *directory* containing pipeline weights saved using
[`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information. specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overritten components are then directly passed to the pipelines `__init__`
method. See example below for more information.
<Tip>
Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.*
`"CompVis/stable-diffusion-v1-4"`
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.
</Tip>
Examples:
```py
>>> from diffusers import DiffusionPipeline
>>> # Download pipeline from huggingface.co and cache.
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
>>> # Download pipeline that requires an authorization token
>>> # For more information on access tokens, please refer to this section
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
>>> # Download pipeline, but overwrite scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
>>> pipeline = DiffusionPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
... )
```
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
@@ -165,10 +283,26 @@ class DiffusionPipeline(ConfigMixin):
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
provider = kwargs.pop("provider", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.get_config_dict(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
)
# make sure we only download sub-folders and `diffusers` filenames
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
@@ -177,6 +311,7 @@ class DiffusionPipeline(ConfigMixin):
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
)
else:
cached_folder = pretrained_model_name_or_path
@@ -259,6 +394,8 @@ class DiffusionPipeline(ConfigMixin):
loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):

View File

@@ -40,8 +40,8 @@ available a colab notebook to directly try them out.
| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
@@ -70,7 +70,7 @@ not be used for training. If you want to store the gradients during the forward
## Contribution
We are more than happy about any contribution to the offically supported pipelines 🤗. We aspire
We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire
all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file iteslf, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline.
@@ -134,7 +134,7 @@ with autocast("cuda"):
images[0].save("fantasy_landscape.png")
```
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
### Tweak prompts reusing seeds and latents
@@ -179,4 +179,4 @@ with autocast("cuda"):
images[0].save("cat_on_bench.png")
```
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb)
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)

View File

@@ -1,4 +1,4 @@
from ..utils import is_transformers_available
from ..utils import is_onnx_available, is_transformers_available
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
@@ -14,3 +14,6 @@ if is_transformers_available():
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
)
if is_transformers_available() and is_onnx_available():
from .stable_diffusion import StableDiffusionOnnxPipeline

View File

@@ -52,21 +52,26 @@ class DDIMPipeline(DiffusionPipeline):
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
batch_size (:obj:`int`, *optional*, defaults to 1):
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (:obj:`torch.Generator`, *optional*):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
eta (:obj:`float`, *optional*, defaults to 0.0):
eta (`float`, *optional*, defaults to 0.0):
The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
num_inference_steps (:obj:`int`, *optional*, defaults to 50):
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`):
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`):
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:

View File

@@ -50,16 +50,21 @@ class DDPMPipeline(DiffusionPipeline):
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
batch_size (:obj:`int`, *optional*, defaults to 1):
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (:obj:`torch.Generator`, *optional*):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`):
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`):
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")

View File

@@ -85,9 +85,14 @@ class LDMTextToImagePipeline(DiffusionPipeline):
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
@@ -686,7 +691,6 @@ class LDMBertModel(LDMBertPreTrainedModel):
output_hidden_states=None,
return_dict=None,
):
outputs = self.model(
input_ids,
attention_mask=attention_mask,

View File

@@ -38,7 +38,6 @@ class LDMPipeline(DiffusionPipeline):
return_dict: bool = True,
**kwargs,
) -> Union[Tuple, ImagePipelineOutput]:
r"""
Args:
batch_size (`int`, *optional*, defaults to 1):
@@ -51,9 +50,14 @@ class LDMPipeline(DiffusionPipeline):
expense of slower inference.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:

View File

@@ -30,7 +30,7 @@ class PNDMPipeline(DiffusionPipeline):
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
unet (:obj:`UNet2DModel`): U-Net architecture to denoise the encoded image latents.
unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image.
"""
@@ -55,20 +55,22 @@ class PNDMPipeline(DiffusionPipeline):
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
batch_size (:obj:`int`, `optional`, defaults to 1): The number of images to generate.
num_inference_steps (:
obj:`int`, `optional`, defaults to 50): The number of denoising steps. More denoising steps usually
lead to a higher quality image at the expense of slower inference.
generator (:
obj:`torch.Generator`, `optional`): A [torch
batch_size (`int`, `optional`, defaults to 1): The number of images to generate.
num_inference_steps (`int`, `optional`, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator`, `optional`): A [torch
generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (:
obj:`str`, `optional`, defaults to :obj:`"pil"`): The output format of the generate image. Choose
between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (:
obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to return a
output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose
between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a
[`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf

View File

@@ -36,16 +36,21 @@ class ScoreSdeVePipeline(DiffusionPipeline):
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Args:
batch_size (:obj:`int`, *optional*, defaults to 1):
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (:obj:`torch.Generator`, *optional*):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`):
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`):
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:
@@ -75,7 +80,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
# correction step
for _ in range(self.scheduler.correct_steps):
for _ in range(self.scheduler.config.correct_steps):
model_output = self.unet(sample, sigma_t).sample
sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample

View File

@@ -12,7 +12,7 @@ The summary of the model is the following:
- Stable Diffusion has the same architecture as [Latent Diffusion](https://arxiv.org/abs/2112.10752) but uses a frozen CLIP Text Encoder instead of training the text encoder jointly with the diffusion model.
- An in-detail explanation of the Stable Diffusion model can be found under [Stable Diffusion with 🧨 Diffusers](https://huggingface.co/blog/stable_diffusion).
- If you don't want to rely on the Hugging Face Hub and having to pass a authentification token, you can
- If you don't want to rely on the Hugging Face Hub and having to pass a authentication token, you can
download the weights with `git lfs install; git clone https://huggingface.co/CompVis/stable-diffusion-v1-4` and instead pass the local path to the cloned folder to `from_pretrained` as shown below.
- Stable Diffusion can work with a variety of different samplers as is shown below.
@@ -21,8 +21,8 @@ download the weights with `git lfs install; git clone https://huggingface.co/Com
| Pipeline | Tasks | Colab
|---|---|:---:|
| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [pipeline_stable_diffusion_img2img](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
| [pipeline_stable_diffusion_inpaint](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [pipeline_stable_diffusion_img2img](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [pipeline_stable_diffusion_inpaint](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
## Examples:

View File

@@ -1,4 +1,3 @@
# flake8: noqa
from dataclasses import dataclass
from typing import List, Union
@@ -7,7 +6,7 @@ import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, is_transformers_available
from ...utils import BaseOutput, is_onnx_available, is_transformers_available
@dataclass
@@ -33,3 +32,6 @@ if is_transformers_available():
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .safety_checker import StableDiffusionSafetyChecker
if is_transformers_available() and is_onnx_available():
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline

View File

@@ -6,6 +6,7 @@ import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -53,6 +54,21 @@ class StableDiffusionPipeline(DiffusionPipeline):
):
super().__init__()
scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -136,13 +152,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
Returns:
`~pipelines.stable_diffusion.StableDiffusionPipelineOutput` if `return_dict` is True, otherwise a tuple.
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
@@ -198,29 +215,29 @@ class StableDiffusionPipeline(DiffusionPipeline):
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# get the initial random noise unless the user supplied it
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_device = "cpu" if self.device.type == "mps" else self.device
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None:
latents = torch.randn(
latents_shape,
generator=generator,
device=self.device,
device=latents_device,
dtype=text_embeddings.dtype,
)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
latents = latents.to(latents_device)
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
if accepts_offset:
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.timesteps = self.scheduler.timesteps.to(latents_device)
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
self.scheduler.timesteps = torch.tensor(self.scheduler.timesteps, device=self.device)
# if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]

View File

@@ -1,4 +1,5 @@
import inspect
import warnings
from typing import List, Optional, Union
import numpy as np
@@ -7,6 +8,7 @@ import torch
import PIL
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -64,6 +66,21 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
):
super().__init__()
scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -99,7 +116,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
back to computing attention in one step.
"""
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slice(None)
self.enable_attention_slicing(None)
@torch.no_grad()
def __call__(
@@ -146,13 +163,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
Returns:
`~pipelines.stable_diffusion.StableDiffusionPipelineOutput` if `return_dict` is True, otherwise a tuple.
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
@@ -168,16 +186,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
offset = 0
if accepts_offset:
offset = 1
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
if not isinstance(init_image, torch.FloatTensor):
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
# encode the init image into latents and scale the latents
@@ -189,6 +200,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_latents = torch.cat([init_latents] * batch_size)
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -248,7 +260,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[t_index]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS

View File

@@ -1,4 +1,5 @@
import inspect
import warnings
from typing import List, Optional, Union
import numpy as np
@@ -8,13 +9,18 @@ import PIL
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
from ...utils import logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__)
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
@@ -78,6 +84,22 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
):
super().__init__()
scheduler = scheduler.set_format("pt")
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -113,7 +135,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
back to computing attention in one step.
"""
# set slice_size = `None` to disable `set_attention_slice`
self.enable_attention_slice(None)
self.enable_attention_slicing(None)
@torch.no_grad()
def __call__(
@@ -140,8 +162,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
process. This is the image whose masked region will be inpainted.
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
converted to a single channel (luminance) before use.
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
is 1, the denoising process will be run on the masked area for the full number of iterations specified
@@ -164,13 +187,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
Returns:
`~pipelines.stable_diffusion.StableDiffusionPipelineOutput` if `return_dict` is True, otherwise a tuple.
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
@@ -186,20 +210,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
offset = 0
if accepts_offset:
offset = 1
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
self.scheduler.set_timesteps(num_inference_steps)
# preprocess image
init_image = preprocess_image(init_image).to(self.device)
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
init_image = init_image.to(self.device)
# encode the init image into latents and scale the latents
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
@@ -209,14 +228,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
init_latents_orig = init_latents
# preprocess mask
mask = preprocess_mask(mask_image).to(self.device)
mask = torch.cat([mask] * batch_size)
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.to(self.device)
mask = torch.cat([mask_image] * batch_size)
# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]

View File

@@ -0,0 +1,160 @@
import inspect
from typing import List, Optional, Union
import numpy as np
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...onnx_utils import OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
class StableDiffusionOnnxPipeline(DiffusionPipeline):
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
tokenizer: CLIPTokenizer
unet: OnnxRuntimeModel
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
def __init__(
self,
vae_decoder: OnnxRuntimeModel,
text_encoder: OnnxRuntimeModel,
tokenizer: CLIPTokenizer,
unet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("np")
self.register_modules(
vae_decoder=vae_decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
eta: Optional[float] = 0.0,
latents: Optional[np.ndarray] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
max_length = text_input.input_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
)
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
# get the initial random noise unless the user supplied it
latents_shape = (batch_size, 4, height // 8, width // 8)
if latents is None:
latents = np.random.randn(*latents_shape).astype(np.float32)
elif latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[i]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
)
noise_pred = noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae_decoder(latent_sample=latents)[0]
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
# run safety checker
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@@ -13,7 +13,7 @@ logger = logging.get_logger(__name__)
def cosine_distance(image_embeds, text_embeds):
normalized_image_embeds = nn.functional.normalize(image_embeds)
normalized_text_embeds = nn.functional.normalize(text_embeds)
return torch.mm(normalized_image_embeds, normalized_text_embeds.T)
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
class StableDiffusionSafetyChecker(PreTrainedModel):
@@ -78,3 +78,29 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
)
return images, has_nsfw_concepts
@torch.inference_mode()
def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
cos_dist = cosine_distance(image_embeds, self.concept_embeds)
# increase this value to create a stronger `nsfw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
# special_scores = special_scores.round(decimals=3)
special_care = torch.any(special_scores > 0, dim=1)
special_adjustment = special_care * 0.01
special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
# concept_scores = concept_scores.round(decimals=3)
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
images[has_nsfw_concepts] = 0.0 # black image
return images, has_nsfw_concepts

View File

@@ -45,19 +45,24 @@ class KarrasVePipeline(DiffusionPipeline):
) -> Union[Tuple, ImagePipelineOutput]:
r"""
Args:
batch_size (:obj:`int`, *optional*, defaults to 1):
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (:obj:`torch.Generator`, *optional*):
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
num_inference_steps (:obj:`int`, *optional*, defaults to 50):
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`):
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`):
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")

View File

@@ -1,7 +1,7 @@
# Schedulers
- Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps.
- Schedulers can be used interchangable between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers can be used interchangeable between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers are available in numpy, but can easily be transformed into PyTorch.
## API

View File

@@ -12,17 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_scipy_available
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin
from ..utils import is_flax_available, is_scipy_available, is_torch_available
if is_torch_available():
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin
else:
from ..utils.dummy_pt_objects import * # noqa F403
if is_flax_available():
from .scheduling_ddim_flax import FlaxDDIMScheduler
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
else:
from ..utils.dummy_flax_objects import * # noqa F403
if is_scipy_available():
from .scheduling_lms_discrete import LMSDiscreteScheduler
else:
from ..utils.dummy_scipy_objects import * # noqa F403
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403

View File

@@ -16,6 +16,7 @@
# and https://github.com/hojonathanho/diffusion
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
@@ -59,6 +60,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
Args:
@@ -68,12 +74,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional): TODO
timestep_values (`np.ndarray`, optional): TODO
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
if alpha for final step is 1 or the final alpha of the "non-previous" one.
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -86,9 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
timestep_values: Optional[np.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
tensor_format: str = "pt",
):
if trained_betas is not None:
@@ -109,7 +121,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this paratemer simply to one or
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
@@ -130,19 +142,31 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance
def set_timesteps(self, num_inference_steps: int, offset: int = 0):
def set_timesteps(self, num_inference_steps: int, **kwargs):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`): TODO
"""
offset = self.config.steps_offset
if "offset" in kwargs:
warnings.warn(
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
" Please pass `steps_offset` to `__init__` instead.",
DeprecationWarning,
)
offset = kwargs["offset"]
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
)[::-1].copy()
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
self.timesteps += offset
self.set_format(tensor_format=self.tensor_format)
@@ -171,7 +195,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
`SchedulerOutput`: updated sample in the diffusion chain.
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:
@@ -187,7 +213,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointingc to x_t"
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
@@ -242,6 +268,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5

View File

@@ -0,0 +1,274 @@
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)
@flax.struct.dataclass
class DDIMSchedulerState:
# setable values
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
state: DDIMSchedulerState
class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
if alpha for final step is 1 or the final alpha of the "non-previous" one.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps)
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def set_timesteps(
self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0
) -> DDIMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`DDIMSchedulerState`):
the `FlaxDDIMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
timesteps = timesteps + offset
return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
def step(
self,
state: DDIMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
eta: float = 0.0,
use_clipped_model_output: bool = False,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
variance = self._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
if use_clipped_model_output:
# the model_output is always re-derived from the clipped x_0 in Glide
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
key = random.split(key, num=1)
noise = random.normal(key=key, shape=model_output.shape)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample, state)
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
def add_noise(
self,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -58,6 +58,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
Langevin dynamics sampling.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
Args:
@@ -67,7 +72,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional): TODO
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
variance_type (`str`):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
@@ -89,7 +95,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
clip_sample: bool = True,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
@@ -143,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
if variance_type is None:
variance_type = self.config.variance_type
# hacks - were probs added for training stability
# hacks - were probably added for training stability
if variance_type == "fixed_small":
variance = self.clip(variance, min_value=1e-20)
# for rl-diffuser https://arxiv.org/abs/2205.09991
@@ -182,14 +187,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
`SchedulerOutput`: updated sample in the diffusion chain.
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
t = timestep
@@ -244,7 +250,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5

View File

@@ -0,0 +1,277 @@
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)
@flax.struct.dataclass
class DDPMSchedulerState:
# setable values
timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
state: DDPMSchedulerState
class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
Langevin dynamics sampling.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
variance_type (`str`):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
self.one = jnp.array(1.0)
self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps)
self.variance_type = variance_type
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`DDIMSchedulerState`):
the `FlaxDDPMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
timesteps = jnp.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps
)[::-1]
return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
def _get_variance(self, t, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
# For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
if variance_type is None:
variance_type = self.config.variance_type
# hacks - were probably added for training stability
if variance_type == "fixed_small":
variance = jnp.clip(variance, a_min=1e-20)
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif variance_type == "fixed_small_log":
variance = jnp.log(jnp.clip(variance, a_min=1e-20))
elif variance_type == "fixed_large":
variance = self.betas[t]
elif variance_type == "fixed_large_log":
# Glide max_log
variance = jnp.log(self.betas[t])
elif variance_type == "learned":
return predicted_variance
elif variance_type == "learned_range":
min_log = variance
max_log = self.betas[t]
frac = (predicted_variance + 1) / 2
variance = frac * max_log + (1 - frac) * min_log
return variance
def step(
self,
state: DDPMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
predict_epsilon: bool = True,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`random.KeyArray`): a PRNG key.
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
else:
predicted_variance = None
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
pred_original_sample = model_output
# 3. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
variance = 0
if t > 0:
key = random.split(key, num=1)
noise = random.normal(key=key, shape=model_output.shape)
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
pred_prev_sample = pred_prev_sample + variance
if not return_dict:
return (pred_prev_sample, state)
return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state)
def add_noise(
self,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -34,7 +34,7 @@ class KarrasVeOutput(BaseOutput):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Derivate of predicted original image sample (x_0).
Derivative of predicted original image sample (x_0).
"""
prev_sample: torch.FloatTensor
@@ -50,6 +50,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
differential equations." https://arxiv.org/abs/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
@@ -100,7 +105,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [
(self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1)))
(
self.config.sigma_max
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in self.timesteps
]
self.schedule = np.array(self.schedule, dtype=np.float32)
@@ -116,13 +124,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
TODO Args:
"""
if self.s_min <= sigma <= self.s_max:
gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
if self.config.s_min <= sigma <= self.config.s_max:
gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
else:
gamma = 0
# sample eps ~ N(0, S_noise^2 * I)
eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
sigma_hat = sigma + gamma * sigma
sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
@@ -147,8 +155,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
Returns:
[`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
[`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""

View File

@@ -0,0 +1,228 @@
# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@flax.struct.dataclass
class KarrasVeSchedulerState:
# setable values
num_inference_steps: Optional[int] = None
timesteps: Optional[jnp.ndarray] = None
schedule: Optional[jnp.ndarray] = None # sigma(t_i)
@classmethod
def create(cls):
return cls()
@dataclass
class FlaxKarrasVeOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
Derivate of predicted original image sample (x_0).
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
"""
prev_sample: jnp.ndarray
derivative: jnp.ndarray
state: KarrasVeSchedulerState
class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
the VE column of Table 1 from [1] for reference.
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
differential equations." https://arxiv.org/abs/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
Args:
sigma_min (`float`): minimum noise magnitude
sigma_max (`float`): maximum noise magnitude
s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
A reasonable range is [1.000, 1.011].
s_churn (`float`): the parameter controlling the overall amount of stochasticity.
A reasonable range is [0, 100].
s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
A reasonable range is [0, 10].
s_max (`float`): the end value of the sigma range where we add noise.
A reasonable range is [0.2, 80].
"""
@register_to_config
def __init__(
self,
sigma_min: float = 0.02,
sigma_max: float = 100,
s_noise: float = 1.007,
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
):
self.state = KarrasVeSchedulerState.create()
def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`KarrasVeSchedulerState`):
the `FlaxKarrasVeScheduler` state data class.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
schedule = [
(
self.config.sigma_max
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in timesteps
]
return state.replace(
num_inference_steps=num_inference_steps,
schedule=jnp.array(schedule, dtype=jnp.float32),
timesteps=timesteps,
)
def add_noise_to_input(
self,
state: KarrasVeSchedulerState,
sample: jnp.ndarray,
sigma: float,
key: random.KeyArray,
) -> Tuple[jnp.ndarray, float]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
TODO Args:
"""
if self.config.s_min <= sigma <= self.config.s_max:
gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1)
else:
gamma = 0
# sample eps ~ N(0, S_noise^2 * I)
key = random.split(key, num=1)
eps = self.config.s_noise * random.normal(key=key, shape=sample.shape)
sigma_hat = sigma + gamma * sigma
sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
return sample_hat, sigma_hat
def step(
self,
state: KarrasVeSchedulerState,
model_output: jnp.ndarray,
sigma_hat: float,
sigma_prev: float,
sample_hat: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxKarrasVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion
chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
pred_original_sample = sample_hat + sigma_hat * model_output
derivative = (sample_hat - pred_original_sample) / sigma_hat
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
if not return_dict:
return (sample_prev, derivative, state)
return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
def step_correct(
self,
state: KarrasVeSchedulerState,
model_output: jnp.ndarray,
sigma_hat: float,
sigma_prev: float,
sample_hat: jnp.ndarray,
sample_prev: jnp.ndarray,
derivative: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxKarrasVeOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. TODO complete description
Args:
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
"""
pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
if not return_dict:
return (sample_prev, derivative, state)
return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state)
def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()

View File

@@ -29,6 +29,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
@@ -36,10 +41,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional): TODO
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
timestep_values (`np.ndarry`, optional): TODO
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -52,7 +57,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
timestep_values: Optional[np.ndarray] = None,
tensor_format: str = "pt",
):
if trained_betas is not None:
@@ -109,14 +113,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
self.num_inference_steps = num_inference_steps
self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
low_idx = np.floor(self.timesteps).astype(int)
high_idx = np.ceil(self.timesteps).astype(int)
frac = np.mod(self.timesteps, 1.0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
self.sigmas = np.concatenate([sigmas, [0.0]])
self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.derivatives = []
@@ -143,7 +147,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
sigma = self.sigmas[timestep]
@@ -177,6 +183,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.sigmas.device)
sigmas = self.match_shape(self.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas

View File

@@ -0,0 +1,207 @@
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@flax.struct.dataclass
class LMSDiscreteSchedulerState:
# setable values
num_inference_steps: Optional[int] = None
timesteps: Optional[jnp.ndarray] = None
sigmas: Optional[jnp.ndarray] = None
derivatives: jnp.ndarray = jnp.array([])
@classmethod
def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], sigmas=sigmas)
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
state: LMSDiscreteSchedulerState
class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
self.state = LMSDiscreteSchedulerState.create(
num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
)
def get_lms_coefficient(self, state, order, t, current_order):
"""
Compute a linear multistep coefficient.
Args:
order (TODO):
t (TODO):
current_order (TODO):
"""
def lms_derivative(tau):
prod = 1.0
for k in range(order):
if current_order == k:
continue
prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k])
return prod
integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]
return integrated_coeff
def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`LMSDiscreteSchedulerState`):
the `FlaxLMSDiscreteScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=jnp.float32)
low_idx = jnp.floor(timesteps).astype(int)
high_idx = jnp.ceil(timesteps).astype(int)
frac = jnp.mod(timesteps, 1.0)
sigmas = jnp.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
sigmas = jnp.concatenate([sigmas, jnp.array([0.0])]).astype(jnp.float32)
return state.replace(
num_inference_steps=num_inference_steps,
timesteps=timesteps,
derivatives=jnp.array([]),
sigmas=sigmas,
)
def step(
self,
state: LMSDiscreteSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
order: int = 4,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
sigma = state.sigmas[timestep]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
state = state.replace(derivatives=state.derivatives.append(derivative))
if len(state.derivatives) > order:
state = state.replace(derivatives=state.derivatives.pop(0))
# 3. Compute linear multistep coefficients
order = min(timestep + 1, order)
lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]
# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives))
)
if not return_dict:
return (prev_sample, state)
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
def add_noise(
self,
state: LMSDiscreteSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sigmas = self.match_shape(state.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
@@ -58,6 +59,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
namely Runge-Kutta method and a linear multi-step method.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
Args:
@@ -67,11 +73,20 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional): TODO
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
skip_prk_steps (`bool`):
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
before plms steps; defaults to `False`.
set_alpha_to_one (`bool`, default `False`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
"""
@@ -83,8 +98,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
tensor_format: str = "pt",
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
@@ -102,7 +119,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.one = np.array(1.0)
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
@@ -118,7 +135,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self._offset = 0
self.prk_timesteps = None
self.plms_timesteps = None
self.timesteps = None
@@ -126,21 +142,31 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`): TODO
"""
offset = self.config.steps_offset
if "offset" in kwargs:
warnings.warn(
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
" Please pass `steps_offset` to `__init__` instead."
)
offset = kwargs["offset"]
self.num_inference_steps = num_inference_steps
self._timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
self._offset = offset
self._timesteps = np.array([t + self._offset for t in self._timesteps])
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
self._timesteps += offset
if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
@@ -186,7 +212,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
`SchedulerOutput`: updated sample in the diffusion chain.
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
@@ -213,7 +241,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:
@@ -222,7 +251,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
)
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
prev_timestep = timestep - diff_to_prev
timestep = self.prk_timesteps[self.counter // 4 * 4]
if self.counter % 4 == 0:
@@ -267,7 +296,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
[`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:
@@ -283,7 +313,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
)
prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
if self.counter != 1:
self.ets.append(model_output)
@@ -313,7 +343,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(tδ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
@@ -326,8 +356,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(tδ)
alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -355,7 +385,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> torch.Tensor:
if self.tensor_format == "pt":
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5

View File

@@ -0,0 +1,406 @@
# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return jnp.array(betas, dtype=jnp.float32)
@flax.struct.dataclass
class PNDMSchedulerState:
# setable values
_timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
_offset: int = 0
prk_timesteps: Optional[jnp.ndarray] = None
plms_timesteps: Optional[jnp.ndarray] = None
timesteps: Optional[jnp.ndarray] = None
# running values
cur_model_output: Optional[jnp.ndarray] = None
counter: int = 0
cur_sample: Optional[jnp.ndarray] = None
ets: jnp.ndarray = jnp.array([])
@classmethod
def create(cls, num_train_timesteps: int):
return cls(_timesteps=jnp.arange(0, num_train_timesteps)[::-1])
@dataclass
class FlaxSchedulerOutput(SchedulerOutput):
state: PNDMSchedulerState
class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
"""
Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
namely Runge-Kutta method and a linear multi-step method.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
skip_prk_steps (`bool`):
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
before plms steps; defaults to `False`.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
skip_prk_steps: bool = False,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4
self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)
def set_timesteps(
self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0
) -> PNDMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`PNDMSchedulerState`):
the `FlaxPNDMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# rounding to avoid issues when num_inference_step is power of 3
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
_timesteps = _timesteps + offset
state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps)
if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
# produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
# is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
state = state.replace(
prk_timesteps=jnp.array([]),
plms_timesteps=jnp.concatenate(
[state._timesteps[:-1], state._timesteps[-2:-1], state._timesteps[-1:]]
)[::-1],
)
else:
prk_timesteps = jnp.array(state._timesteps[-self.pndm_order :]).repeat(2) + jnp.tile(
jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
state = state.replace(
prk_timesteps=(prk_timesteps[:-1].repeat(2)[1:-1])[::-1],
plms_timesteps=state._timesteps[:-3][::-1],
)
return state.replace(
timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
ets=jnp.array([]),
counter=0,
)
def step(
self,
state: PNDMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
return self.step_prk(
state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
)
else:
return self.step_plms(
state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
)
def step_prk(
self,
state: PNDMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1])
timestep = state.prk_timesteps[state.counter // 4 * 4]
if state.counter % 4 == 0:
state = state.replace(
cur_model_output=state.cur_model_output + 1 / 6 * model_output,
ets=state.ets.append(model_output),
cur_sample=sample,
)
elif (self.counter - 1) % 4 == 0:
state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
elif (self.counter - 2) % 4 == 0:
state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
elif (self.counter - 3) % 4 == 0:
model_output = state.cur_model_output + 1 / 6 * model_output
state = state.replace(cur_model_output=0)
# cur_sample should not be `None`
cur_sample = state.cur_sample if state.cur_sample is not None else sample
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
state = state.replace(counter=state.counter + 1)
if not return_dict:
return (prev_sample, state)
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
def step_plms(
self,
state: PNDMSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
Args:
state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if not self.config.skip_prk_steps and len(state.ets) < 3:
raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run "
"in 'prk' mode for at least 12 iterations "
"See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
"for more information."
)
prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0)
if state.counter != 1:
state = state.replace(ets=state.ets.append(model_output))
else:
prev_timestep = timestep
timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
if len(state.ets) == 1 and state.counter == 0:
model_output = model_output
state = state.replace(cur_sample=sample)
elif len(state.ets) == 1 and state.counter == 1:
model_output = (model_output + state.ets[-1]) / 2
sample = state.cur_sample
state = state.replace(cur_sample=None)
elif len(state.ets) == 2:
model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
elif len(state.ets) == 3:
model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
else:
model_output = (1 / 24) * (
55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
)
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
state = state.replace(counter=state.counter + 1)
if not return_dict:
return (prev_sample, state)
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(tδ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
# Notation (<variable name> -> <name in paper>
# alpha_prod_t -> α_t
# alpha_prod_t_prev -> α_(tδ)
# beta_prod_t -> (1 - α_t)
# beta_prod_t_prev -> (1 - α_(tδ))
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(tδ)
alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
# corresponds to (α_(tδ) - α_t) divided by
# denominator of x_t in formula (9) and plus 1
# Note: (α_(tδ) - α_t) / (sqrt(α_t) * (sqrt(α_(tδ)) + sqr(α_t))) =
# sqrt(α_(tδ)) / sqrt(α_t))
sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
# corresponds to denominator of e_θ(x_t, t) in formula (9)
model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
alpha_prod_t * beta_prod_t * alpha_prod_t_prev
) ** (0.5)
# full formula (9)
prev_sample = (
sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
)
return prev_sample
def add_noise(
self,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -49,14 +49,20 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
For more information, see the original paper: https://arxiv.org/abs/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
snr (`float`):
coefficient weighting the step from the model_output sample (from the network) to the random noise.
sigma_min (`float`):
initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
distribution of the data.
sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
sampling_eps (`float`): the end value of sampling, where timesteps decrease progessively from 1 to
sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
epsilon.
correct_steps (`int`): number of correction steps performed on a produced sample.
tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
@@ -139,7 +145,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
elif tensor_format == "pt":
return torch.where(
timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device)
timesteps == 0,
torch.zeros_like(t.to(timesteps.device)),
self.discrete_sigmas[timesteps - 1].to(timesteps.device),
)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
@@ -180,7 +188,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
[`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if "seed" in kwargs and kwargs["seed"] is not None:
@@ -196,8 +205,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
) # torch.repeat_interleave(timestep, sample.shape[0])
timesteps = (timestep * (len(self.timesteps) - 1)).long()
# mps requires indices to be in the same device, so we use cpu as is the default with cuda
timesteps = timesteps.to(self.discrete_sigmas.device)
sigma = self.discrete_sigmas[timesteps].to(sample.device)
adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep)
adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
drift = self.zeros_like(sample)
diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
@@ -236,7 +248,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain.
[`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if "seed" in kwargs and kwargs["seed"] is not None:

View File

@@ -0,0 +1,260 @@
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@flax.struct.dataclass
class ScoreSdeVeSchedulerState:
# setable values
timesteps: Optional[jnp.ndarray] = None
discrete_sigmas: Optional[jnp.ndarray] = None
sigmas: Optional[jnp.ndarray] = None
@classmethod
def create(cls):
return cls()
@dataclass
class FlaxSdeVeOutput(SchedulerOutput):
"""
Output class for the ScoreSdeVeScheduler's step function output.
Args:
state (`ScoreSdeVeSchedulerState`):
prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
"""
state: ScoreSdeVeSchedulerState
prev_sample: jnp.ndarray
prev_sample_mean: Optional[jnp.ndarray] = None
class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
"""
The variance exploding stochastic differential equation (SDE) scheduler.
For more information, see the original paper: https://arxiv.org/abs/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
snr (`float`):
coefficient weighting the step from the model_output sample (from the network) to the random noise.
sigma_min (`float`):
initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
distribution of the data.
sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
epsilon.
correct_steps (`int`): number of correction steps performed on a produced sample.
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 2000,
snr: float = 0.15,
sigma_min: float = 0.01,
sigma_max: float = 1348.0,
sampling_eps: float = 1e-5,
correct_steps: int = 1,
):
state = ScoreSdeVeSchedulerState.create()
self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
def set_timesteps(
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None
) -> ScoreSdeVeSchedulerState:
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
"""
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
timesteps = jnp.linspace(1, sampling_eps, num_inference_steps)
return state.replace(timesteps=timesteps)
def set_sigmas(
self,
state: ScoreSdeVeSchedulerState,
num_inference_steps: int,
sigma_min: float = None,
sigma_max: float = None,
sampling_eps: float = None,
) -> ScoreSdeVeSchedulerState:
"""
Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
The sigmas control the weight of the `drift` and `diffusion` components of sample update.
Args:
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, optional):
initial noise scale value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
"""
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
if state.timesteps is None:
state = self.set_timesteps(state, num_inference_steps, sampling_eps)
discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps))
sigmas = jnp.array([sigma_min * (sigma_max / sigma_min) ** t for t in state.timesteps])
return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas)
def get_adjacent_sigma(self, state, timesteps, t):
return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1])
def step_pred(
self,
state: ScoreSdeVeSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
return_dict: bool = True,
) -> Union[FlaxSdeVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
if state.timesteps is None:
raise ValueError(
"`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
)
timestep = timestep * jnp.ones(
sample.shape[0],
)
timesteps = (timestep * (len(state.timesteps) - 1)).long()
sigma = state.discrete_sigmas[timesteps]
adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep)
drift = jnp.zeros_like(sample)
diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
# also equation 47 shows the analog from SDE models to ancestral sampling methods
drift = drift - diffusion[:, None, None, None] ** 2 * model_output
# equation 6: sample noise for the diffusion term of
key = random.split(key, num=1)
noise = random.normal(key=key, shape=sample.shape)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
if not return_dict:
return (prev_sample, prev_sample_mean, state)
return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state)
def step_correct(
self,
state: ScoreSdeVeSchedulerState,
model_output: jnp.ndarray,
sample: jnp.ndarray,
key: random.KeyArray,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
Args:
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
if state.timesteps is None:
raise ValueError(
"`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
)
# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction
key = random.split(key, num=1)
noise = random.normal(key=key, shape=sample.shape)
# compute step size from the model_output, the noise, and the snr
grad_norm = jnp.linalg.norm(model_output)
noise_norm = jnp.linalg.norm(noise)
step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
step_size = step_size * jnp.ones(sample.shape[0])
# compute corrected sample: model_output term and noise term
prev_sample_mean = sample + step_size[:, None, None, None] * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
if not return_dict:
return (prev_sample, state)
return FlaxSdeVeOutput(prev_sample=prev_sample, state=state)
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -14,7 +14,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
import numpy as np
import torch
@@ -27,6 +27,11 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
"""
The variance preserving stochastic differential equation (SDE) scheduler.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more information, see the original paper: https://arxiv.org/abs/2011.13456
UNDER CONSTRUCTION
@@ -35,7 +40,6 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None

View File

@@ -2,12 +2,22 @@ import os
import random
import unittest
from distutils.util import strtobool
from typing import Union
import torch
import PIL.Image
import PIL.ImageOps
import requests
from packaging import version
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
if is_torch_higher_equal_than_1_12:
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
def parse_flag_from_env(key, default=False):
@@ -53,3 +63,32 @@ def slow(test_case):
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
Args:
Loads `image` to a PIL Image.
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
Returns:
`PIL.Image.Image`: A PIL Image.
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = PIL.Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image

View File

@@ -25,6 +25,7 @@ from .import_utils import (
is_flax_available,
is_inflect_available,
is_modelcards_available,
is_onnx_available,
is_scipy_available,
is_tf_available,
is_torch_available,

View File

@@ -0,0 +1,60 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class FlaxModelMixin(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxDDIMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxDDPMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxKarrasVeScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLMSDiscreteScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxPNDMScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxUNet2DConditionModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])

View File

@@ -0,0 +1,165 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class UNet2DConditionModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class UNet2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class VQModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch"])
def get_constant_schedule_with_warmup(*args, **kwargs):
requires_backends(get_constant_schedule_with_warmup, ["torch"])
def get_cosine_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_schedule_with_warmup, ["torch"])
def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"])
def get_linear_schedule_with_warmup(*args, **kwargs):
requires_backends(get_linear_schedule_with_warmup, ["torch"])
def get_polynomial_decay_schedule_with_warmup(*args, **kwargs):
requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"])
def get_scheduler(*args, **kwargs):
requires_backends(get_scheduler, ["torch"])
class DiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DDIMPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DDPMPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class KarrasVePipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LDMPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ScoreSdeVePipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DDPMScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class KarrasVeScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class PNDMScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SchedulerMixin(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EMAModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

View File

@@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends
class LMSDiscreteScheduler(metaclass=DummyObject):
_backends = ["scipy"]
_backends = ["torch", "scipy"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["scipy"])
requires_backends(self, ["torch", "scipy"])

Some files were not shown because too many files have changed in this diff Show More