diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml
index 542920d7f6..d51623e735 100644
--- a/.github/workflows/build_pr_documentation.yml
+++ b/.github/workflows/build_pr_documentation.yml
@@ -9,11 +9,8 @@ concurrency:
jobs:
build:
- uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@use_hf_hub
+ uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: diffusers
- secrets:
- token: ${{ secrets.HF_DOC_PUSH }}
- comment_bot_token: ${{ secrets.HUGGINGFACE_PUSH }}
diff --git a/.github/workflows/delete_doc_comment.yml b/.github/workflows/delete_doc_comment.yml
index e1b2da9567..238dc0bdba 100644
--- a/.github/workflows/delete_doc_comment.yml
+++ b/.github/workflows/delete_doc_comment.yml
@@ -7,10 +7,7 @@ on:
jobs:
delete:
- uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@use_hf_hub
+ uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
with:
pr_number: ${{ github.event.number }}
package: diffusers
- secrets:
- token: ${{ secrets.HF_DOC_PUSH }}
- comment_bot_token: ${{ secrets.HUGGINGFACE_PUSH }}
diff --git a/.gitignore b/.gitignore
index cf81834636..f018a111ea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,4 +163,6 @@ tags
*.lock
# DS_Store (MacOS)
-.DS_Store
\ No newline at end of file
+.DS_Store
+# RL pipelines may produce mp4 outputs
+*.mp4
\ No newline at end of file
diff --git a/README.md b/README.md
index 64cbd15aab..4a944d0459 100644
--- a/README.md
+++ b/README.md
@@ -152,15 +152,7 @@ it before the pipeline and pass it to `from_pretrained`.
```python
from diffusers import LMSDiscreteScheduler
-lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
-
-pipe = StableDiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5",
- revision="fp16",
- torch_dtype=torch.float16,
- scheduler=lms,
-)
-pipe = pipe.to("cuda")
+pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index d8efb5eee3..4491a1eab6 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -10,6 +10,8 @@
- sections:
- local: using-diffusers/loading
title: "Loading Pipelines, Models, and Schedulers"
+ - local: using-diffusers/schedulers
+ title: "Using different Schedulers"
- local: using-diffusers/configuration
title: "Configuring Pipelines, Models, and Schedulers"
- local: using-diffusers/custom_pipeline_overview
@@ -78,6 +80,8 @@
- sections:
- local: api/pipelines/overview
title: "Overview"
+ - local: api/pipelines/alt_diffusion
+ title: "AltDiffusion"
- local: api/pipelines/cycle_diffusion
title: "Cycle Diffusion"
- local: api/pipelines/ddim
diff --git a/docs/source/api/configuration.mdx b/docs/source/api/configuration.mdx
index 45176f55b0..423c31f462 100644
--- a/docs/source/api/configuration.mdx
+++ b/docs/source/api/configuration.mdx
@@ -15,9 +15,9 @@ specific language governing permissions and limitations under the License.
In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are
passed to the respective `__init__` methods in a JSON-configuration file.
-TODO(PVP) - add example and better info here
-
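+For example, for a scheduler the flow looks as follows (a minimal sketch, using [`DDPMScheduler`] purely for illustration):
+
+```python
+>>> from diffusers import DDPMScheduler
+
+>>> # all arguments passed to __init__ are stored on a frozen `config` attribute
+>>> scheduler = DDPMScheduler(num_train_timesteps=1000)
+>>> scheduler.config.num_train_timesteps
+1000
+
+>>> # `save_config` writes the configuration to a JSON file and `load_config` reads it back
+>>> scheduler.save_config("ddpm-config")  # creates ddpm-config/scheduler_config.json
+>>> config = DDPMScheduler.load_config("ddpm-config")
+>>> scheduler = DDPMScheduler.from_config(config)
+```
+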
## ConfigMixin
+
[[autodoc]] ConfigMixin
+ - load_config
- from_config
- save_config
diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx
index 2e1e8798a7..7c1faa8474 100644
--- a/docs/source/api/models.mdx
+++ b/docs/source/api/models.mdx
@@ -22,12 +22,15 @@ The models are built on the base class [`ModelMixin`] that is a `torch.nn.Module`
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput
-## UNet1DModel
-[[autodoc]] UNet1DModel
-
## UNet2DModel
[[autodoc]] UNet2DModel
+## UNet1DOutput
+[[autodoc]] models.unet_1d.UNet1DOutput
+
+## UNet1DModel
+[[autodoc]] UNet1DModel
+
## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
diff --git a/docs/source/api/pipelines/alt_diffusion.mdx b/docs/source/api/pipelines/alt_diffusion.mdx
new file mode 100644
index 0000000000..efa9beb8c0
--- /dev/null
+++ b/docs/source/api/pipelines/alt_diffusion.mdx
@@ -0,0 +1,83 @@
+
+# AltDiffusion
+
+AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, and Ledell Wu.
+
+The abstract of the paper is the following:
+
+*In this work, we present a conceptually simple and effective method to train a strong bilingual multimodal representation model. Starting from the pretrained multimodal representation model CLIP released by OpenAI, we switched its text encoder with a pretrained multilingual text encoder XLM-R, and aligned both languages and image representations by a two-stage training schema consisting of teacher learning and contrastive learning. We validate our method through evaluations of a wide range of tasks. We set new state-of-the-art performances on a bunch of tasks including ImageNet-CN, Flicker30k-CN, and COCO-CN. Further, we obtain very close performances with CLIP on almost all tasks, suggesting that one can simply alter the text encoder in CLIP for extended capabilities such as multilingual understanding.*
+
+
+*Overview*:
+
+| Pipeline | Tasks | Colab | Demo |
+|---|---|:---:|:---:|
+| [pipeline_alt_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py) | *Text-to-Image Generation* | - | - |
+| [pipeline_alt_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | - | - |
+
+## Tips
+
+- AltDiffusion is conceptually exactly the same as [Stable Diffusion](./api/pipelines/stable_diffusion).
+
+- *Run AltDiffusion*
+
+AltDiffusion can be tested very easily with the [`AltDiffusionPipeline`], [`AltDiffusionImg2ImgPipeline`] and the `"BAAI/AltDiffusion"` checkpoint, in exactly the same way as shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation) and the [Image-to-Image Generation Guide](./using-diffusers/img2img).
+
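+For example, a minimal text-to-image sketch (the prompt here is only an illustration) could look as follows:
+
+```python
+>>> import torch
+>>> from diffusers import AltDiffusionPipeline
+
+>>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", torch_dtype=torch.float16)
+>>> pipe = pipe.to("cuda")
+
+>>> # AltDiffusion's text encoder is bilingual, so both English and Chinese prompts work
+>>> prompt = "黑暗精灵公主，非常精致而美丽的脸"
+>>> image = pipe(prompt).images[0]
+```
+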
+- *How to load and use different schedulers.*
+
+The AltDiffusion pipeline uses the [`DDIMScheduler`] by default. But `diffusers` provides many other schedulers that can be used with the AltDiffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
+To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
+
+```python
+>>> from diffusers import AltDiffusionPipeline, EulerDiscreteScheduler
+
+>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion")
+>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+>>> # or
+>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler")
+>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=euler_scheduler)
+```
+
+
+- *How to cover all use cases with a single or multiple pipelines*
+
+If you want to use all possible use cases in a single `DiffusionPipeline`, we recommend using the `components` functionality to instantiate all components in the most memory-efficient way:
+
+```python
+>>> from diffusers import (
+... AltDiffusionPipeline,
+... AltDiffusionImg2ImgPipeline,
+... )
+
+>>> text2img = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion")
+>>> img2img = AltDiffusionImg2ImgPipeline(**text2img.components)
+
+>>> # now you can use text2img(...) and img2img(...) just like the call methods of each respective pipeline
+```
+
+## AltDiffusionPipelineOutput
+[[autodoc]] pipelines.alt_diffusion.AltDiffusionPipelineOutput
+
+## AltDiffusionPipeline
+[[autodoc]] AltDiffusionPipeline
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+
+## AltDiffusionImg2ImgPipeline
+[[autodoc]] AltDiffusionImg2ImgPipeline
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
diff --git a/docs/source/api/pipelines/cycle_diffusion.mdx b/docs/source/api/pipelines/cycle_diffusion.mdx
index 50d2a5c87e..8eecd3d624 100644
--- a/docs/source/api/pipelines/cycle_diffusion.mdx
+++ b/docs/source/api/pipelines/cycle_diffusion.mdx
@@ -39,7 +39,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
-scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
+scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image
diff --git a/docs/source/api/pipelines/latent_diffusion.mdx b/docs/source/api/pipelines/latent_diffusion.mdx
index 4ade13e67b..370d014f5a 100644
--- a/docs/source/api/pipelines/latent_diffusion.mdx
+++ b/docs/source/api/pipelines/latent_diffusion.mdx
@@ -39,9 +39,9 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
## LDMTextToImagePipeline
-[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
+[[autodoc]] LDMTextToImagePipeline
- __call__
## LDMSuperResolutionPipeline
-[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution.LDMSuperResolutionPipeline
+[[autodoc]] LDMSuperResolutionPipeline
- __call__
diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx
index d68961a2fc..74c44fbccd 100644
--- a/docs/source/api/pipelines/overview.mdx
+++ b/docs/source/api/pipelines/overview.mdx
@@ -44,11 +44,13 @@ available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
+| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
+| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
diff --git a/docs/source/api/pipelines/repaint.mdx b/docs/source/api/pipelines/repaint.mdx
index 0b7de8a457..ce262daffa 100644
--- a/docs/source/api/pipelines/repaint.mdx
+++ b/docs/source/api/pipelines/repaint.mdx
@@ -54,7 +54,7 @@ original_image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))
# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
-scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256")
+scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
pipe = pipe.to("cuda")
diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx
index 26d6a210ad..1d22024a53 100644
--- a/docs/source/api/pipelines/stable_diffusion.mdx
+++ b/docs/source/api/pipelines/stable_diffusion.mdx
@@ -34,13 +34,17 @@ For more details about how Stable Diffusion works and how it differs from the ba
### How to load and use different schedulers.
The stable diffusion pipeline uses the [`PNDMScheduler`] by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
-To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
+To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
-from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
+>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
-euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
-pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
+>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
+>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+>>> # or
+>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
```
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index bae507ac11..e4722bec68 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -34,11 +34,13 @@ available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
+| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
+| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx
index 463780a072..a50b476c3d 100644
--- a/docs/source/quicktour.mdx
+++ b/docs/source/quicktour.mdx
@@ -41,7 +41,7 @@ In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generat
```python
>>> from diffusers import DiffusionPipeline
->>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
```
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
@@ -49,13 +49,13 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm
You can move the generator object to GPU, just like you would in PyTorch.
```python
->>> generator.to("cuda")
+>>> pipeline.to("cuda")
```
-Now you can use the `generator` on your text prompt:
+Now you can use the `pipeline` on your text prompt:
```python
->>> image = generator("An image of a squirrel in Picasso style").images[0]
+>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
```
The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).
@@ -82,7 +82,7 @@ just like we did before only that now you need to pass your `AUTH_TOKEN`:
```python
>>> from diffusers import DiffusionPipeline
->>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
+>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
```
If you do not pass your authentication token you will see that the diffusion system will not be correctly
@@ -102,7 +102,7 @@ token. Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned
you can also load the pipeline as follows:
```python
->>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
+>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
```
Running the pipeline is then identical to the code above as it's the same model architecture.
@@ -115,19 +115,20 @@ Running the pipeline is then identical to the code above as it's the same model
Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their
pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
-use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler,
+use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`],
you could use it as follows:
```python
->>> from diffusers import LMSDiscreteScheduler
+>>> from diffusers import EulerDiscreteScheduler
->>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
+>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
->>> generator = StableDiffusionPipeline.from_pretrained(
-... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
-... )
+>>> # change scheduler to Euler
+>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
```
+For more detailed information on how to change between schedulers, please refer to the [Using Schedulers](./using-diffusers/schedulers) guide.
+
[Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model
and can do much more than just generating images from text. We have dedicated a whole documentation page,
just for Stable Diffusion [here](./conceptual/stable_diffusion).
diff --git a/docs/source/training/text2image.mdx b/docs/source/training/text2image.mdx
index 1b04462f77..eb71457cb7 100644
--- a/docs/source/training/text2image.mdx
+++ b/docs/source/training/text2image.mdx
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
# Stable Diffusion text-to-image fine-tuning
-The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) script shows how to fine-tune the stable diffusion model on your own dataset.
+The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) script shows how to fine-tune the stable diffusion model on your own dataset.
diff --git a/docs/source/using-diffusers/conditional_image_generation.mdx b/docs/source/using-diffusers/conditional_image_generation.mdx
index 6273a71d4c..5ed27ac917 100644
--- a/docs/source/using-diffusers/conditional_image_generation.mdx
+++ b/docs/source/using-diffusers/conditional_image_generation.mdx
@@ -44,5 +44,3 @@ You can save the image by simply calling:
```python
>>> image.save("image_of_squirrel_painting.png")
```
-
-
diff --git a/docs/source/using-diffusers/loading.mdx b/docs/source/using-diffusers/loading.mdx
index 2cb980ea61..c97ad5c5d0 100644
--- a/docs/source/using-diffusers/loading.mdx
+++ b/docs/source/using-diffusers/loading.mdx
@@ -19,7 +19,7 @@ In the following we explain in-detail how to easily load:
- *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`]
- *Diffusion Models* via [`ModelMixin.from_pretrained`]
-- *Schedulers* via [`ConfigMixin.from_config`]
+- *Schedulers* via [`SchedulerMixin.from_pretrained`]
## Loading pipelines
@@ -137,15 +137,15 @@ from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultis
repo_id = "runwayml/stable-diffusion-v1-5"
-scheduler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
+scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
# or
-# scheduler = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
+# scheduler = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)
```
Three things are worth paying attention to here.
-- First, the scheduler is loaded with [`ConfigMixin.from_config`] since it only depends on a configuration file and not any parameterized weights
+- First, the scheduler is loaded with [`SchedulerMixin.from_pretrained`]
- Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"` as the configuration of stable diffusion's scheduling is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler)
- Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler` since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`]
@@ -337,8 +337,8 @@ model = UNet2DModel.from_pretrained(repo_id)
## Loading schedulers
-Schedulers cannot be loaded via a `from_pretrained` method, but instead rely on [`ConfigMixin.from_config`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
-Therefore the loading method was given a different name here.
+Schedulers are loaded via [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
+For consistency, we use the same method name as we do for models or pipelines, but no weights are loaded in this case.
In contrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers.
For example, all of:
@@ -367,13 +367,13 @@ from diffusers import (
repo_id = "runwayml/stable-diffusion-v1-5"
-ddpm = DDPMScheduler.from_config(repo_id, subfolder="scheduler")
-ddim = DDIMScheduler.from_config(repo_id, subfolder="scheduler")
-pndm = PNDMScheduler.from_config(repo_id, subfolder="scheduler")
-lms = LMSDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
-euler_anc = EulerAncestralDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
-euler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
-dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
+ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
+lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc`
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
diff --git a/docs/source/using-diffusers/schedulers.mdx b/docs/source/using-diffusers/schedulers.mdx
new file mode 100644
index 0000000000..87ff789747
--- /dev/null
+++ b/docs/source/using-diffusers/schedulers.mdx
@@ -0,0 +1,262 @@
+
+# Schedulers
+
+Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent of each other. This means that one can swap out parts of the pipeline to better customize
+a pipeline to one's use case. The best example of this is the [Schedulers](../api/schedulers.mdx).
+
+Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample,
+schedulers define the whole denoising process, *i.e.*:
+- How many denoising steps?
+- Stochastic or deterministic?
+- What algorithm to use to find the denoised sample?
+
+They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**.
+It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try them out and compare the results.
+
+The following paragraphs show how to do so with the 🧨 Diffusers library.
+
+## Load pipeline
+
+Let's start by loading the stable diffusion pipeline.
+Remember that you have to be a registered user on the 🤗 Hugging Face Hub, and have "click-accepted" the [license](https://huggingface.co/runwayml/stable-diffusion-v1-5) in order to use stable diffusion.
+
+```python
+from huggingface_hub import login
+from diffusers import DiffusionPipeline
+import torch
+
+# first we need to login with our access token
+login()
+
+# Now we can download the pipeline
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+```
+
+Next, we move it to GPU:
+
+```python
+pipeline.to("cuda")
+```
+
+## Access the scheduler
+
+The scheduler is always one of the components of the pipeline and is usually called `"scheduler"`.
+So it can be accessed via the `scheduler` attribute.
+
+```python
+pipeline.scheduler
+```
+
+**Output**:
+```
+PNDMScheduler {
+ "_class_name": "PNDMScheduler",
+ "_diffusers_version": "0.8.0.dev0",
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": false,
+ "num_train_timesteps": 1000,
+ "set_alpha_to_one": false,
+ "skip_prk_steps": true,
+ "steps_offset": 1,
+ "trained_betas": null
+}
+```
+
+We can see that the scheduler is of type [`PNDMScheduler`].
+Cool, now let's compare its performance to that of other schedulers.
+First, we define a prompt on which to test all the different schedulers:
+
+```python
+prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
+```
+
+Next, we create a generator with a fixed random seed, so that the results are reproducible, and run the pipeline:
+
+```python
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+<!-- image: sample generated with the default PNDMScheduler -->
+
+## Changing the scheduler
+
+Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`SchedulerMixin.compatibles`]
+which lists all schedulers that are compatible with it. You can take a look at all schedulers compatible with the Stable Diffusion pipeline as follows.
+
+```python
+pipeline.scheduler.compatibles
+```
+
+**Output**:
+```
+[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
+ diffusers.schedulers.scheduling_ddim.DDIMScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
+ diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
+ diffusers.schedulers.scheduling_pndm.PNDMScheduler,
+ diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
+ diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]
+```
+
+Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions:
+
+- [`LMSDiscreteScheduler`],
+- [`DDIMScheduler`],
+- [`DPMSolverMultistepScheduler`],
+- [`EulerDiscreteScheduler`],
+- [`PNDMScheduler`],
+- [`DDPMScheduler`],
+- [`EulerAncestralDiscreteScheduler`].
+
+We will now run the same prompt with each of the other schedulers and compare the results. To change the scheduler of the pipeline you can make use of the
+convenient [`ConfigMixin.config`] property in combination with the [`ConfigMixin.from_config`] function.
+
+```python
+pipeline.scheduler.config
+```
+
+returns a dictionary of the configuration of the scheduler:
+
+**Output**:
+```
+FrozenDict([('num_train_timesteps', 1000),
+ ('beta_start', 0.00085),
+ ('beta_end', 0.012),
+ ('beta_schedule', 'scaled_linear'),
+ ('trained_betas', None),
+ ('skip_prk_steps', True),
+ ('set_alpha_to_one', False),
+ ('steps_offset', 1),
+ ('_class_name', 'PNDMScheduler'),
+ ('_diffusers_version', '0.8.0.dev0'),
+ ('clip_sample', False)])
+```
+
+This configuration can then be used to instantiate a scheduler
+of a different class that is compatible with the pipeline. Here,
+we change the scheduler to the [`DDIMScheduler`].
+
+```python
+from diffusers import DDIMScheduler
+
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+```
+
+Cool, now we can run the pipeline again to compare the generation quality.
+
+```python
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+<!-- image: sample generated with DDIMScheduler -->
+
+## Compare schedulers
+
+So far we have tried running the stable diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`].
+A number of better schedulers have been released that can be run with far fewer steps; let's compare them here:
+
+[`LMSDiscreteScheduler`] usually leads to better results:
+
+```python
+from diffusers import LMSDiscreteScheduler
+
+pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+<!-- image: sample generated with LMSDiscreteScheduler -->
+
+[`EulerDiscreteScheduler`] and [`EulerAncestralDiscreteScheduler`] can generate high quality results with as few as 30 steps.
+
+```python
+from diffusers import EulerDiscreteScheduler
+
+pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
+image
+```
+
+<!-- image: sample generated with EulerDiscreteScheduler -->
+
+and:
+
+```python
+from diffusers import EulerAncestralDiscreteScheduler
+
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
+image
+```
+
+<!-- image: sample generated with EulerAncestralDiscreteScheduler -->
+
+At the time of writing, [`DPMSolverMultistepScheduler`] gives arguably the best speed/quality trade-off and can be run with as few
+as 20 steps.
+
+```python
+from diffusers import DPMSolverMultistepScheduler
+
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+<!-- image: sample generated with DPMSolverMultistepScheduler -->
+
+As you can see, most images look very similar and are arguably of very similar quality. Which scheduler to choose often depends on the specific use case. A good approach is always to run multiple different
+schedulers and compare the results.
diff --git a/examples/README.md b/examples/README.md
index 29872a7a16..06ce06b9e3 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -42,7 +42,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
-
+| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffuser_locomotion.py) | - | - | coming soon.
## Community
diff --git a/examples/community/README.md b/examples/community/README.md
index fd6fff79c5..5535937dca 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -15,7 +15,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) |
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech)
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) |
-| Composable Stable Diffusion| Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
+| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
| Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
| Multilingual Stable Diffusion| Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
@@ -345,6 +345,8 @@ out = pipe(
### Composable Stable diffusion
+[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models.
+
```python
import torch as th
import numpy as np
diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py
index 0c95fb4358..d6d89283b1 100644
--- a/examples/community/imagic_stable_diffusion.py
+++ b/examples/community/imagic_stable_diffusion.py
@@ -17,7 +17,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import logging
+from diffusers.utils import PIL_INTERPOLATION, logging
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -28,7 +28,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index e4ee7bf3c6..8c5f5b46a7 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -12,7 +12,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, is_accelerate_available, logging
+from diffusers.utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -358,7 +358,7 @@ def get_weighted_text_embeddings(
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
@@ -369,7 +369,7 @@ def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+ mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py
index 12e306a612..268af775a3 100644
--- a/examples/community/lpw_stable_diffusion_onnx.py
+++ b/examples/community/lpw_stable_diffusion_onnx.py
@@ -10,7 +10,7 @@ from diffusers.onnx_utils import OnnxRuntimeModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import logging
+from diffusers.utils import PIL_INTERPOLATION, logging
from transformers import CLIPFeatureExtractor, CLIPTokenizer
@@ -365,7 +365,7 @@ def get_weighted_text_embeddings(
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
return 2.0 * image - 1.0
@@ -375,7 +375,7 @@ def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+ mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 3c9d04abc2..2339e2979d 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -92,7 +92,7 @@ accelerate launch train_dreambooth.py \
With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run DreamBooth training on a 16GB GPU.
-Install `bitsandbytes` with `pip install bitsandbytes`
+To install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
diff --git a/examples/rl/README.md b/examples/rl/README.md
new file mode 100644
index 0000000000..d68f2bf780
--- /dev/null
+++ b/examples/rl/README.md
@@ -0,0 +1,19 @@
+# Overview
+
+These examples show how to run [Diffuser](https://arxiv.org/abs/2205.09991) in Diffusers.
+There are two scripts:
+1. `run_diffuser_locomotion.py` to sample actions and run them in the environment,
+2. `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model.
+
+You will need some RL-specific requirements to run the examples:
+
+```
+pip install -f https://download.pytorch.org/whl/torch_stable.html \
+ free-mujoco-py \
+ einops \
+ gym==0.24.1 \
+ protobuf==3.20.1 \
+ git+https://github.com/rail-berkeley/d4rl.git \
+ mediapy \
+ Pillow==9.0.0
+```
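+
+With the dependencies installed, each script can then be run directly, for example:
+
+```
+python run_diffuser_locomotion.py
+```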
diff --git a/examples/rl/run_diffuser_gen_trajectories.py b/examples/rl/run_diffuser_gen_trajectories.py
new file mode 100644
index 0000000000..5bb068cc9f
--- /dev/null
+++ b/examples/rl/run_diffuser_gen_trajectories.py
@@ -0,0 +1,57 @@
+import d4rl # noqa
+import gym
+import tqdm
+from diffusers.experimental import ValueGuidedRLPipeline
+
+
+config = dict(
+ n_samples=64,
+ horizon=32,
+ num_inference_steps=20,
+ n_guide_steps=0,
+ scale_grad_by_std=True,
+ scale=0.1,
+ eta=0.0,
+ t_grad_cutoff=2,
+ device="cpu",
+)
+
+
+if __name__ == "__main__":
+ env_name = "hopper-medium-v2"
+ env = gym.make(env_name)
+
+ pipeline = ValueGuidedRLPipeline.from_pretrained(
+ "bglick13/hopper-medium-v2-value-function-hor32",
+ env=env,
+ )
+
+ env.seed(0)
+ obs = env.reset()
+ total_reward = 0
+ total_score = 0
+ T = 1000
+ rollout = [obs.copy()]
+ try:
+ for t in tqdm.tqdm(range(T)):
+ # Call the policy
+ denorm_actions = pipeline(obs, planning_horizon=32)
+
+ # execute action in environment
+ next_observation, reward, terminal, _ = env.step(denorm_actions)
+ score = env.get_normalized_score(total_reward)
+ # update return
+ total_reward += reward
+ total_score += score
+ print(
+ f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
+ f" {total_score}"
+ )
+ # save observations for rendering
+ rollout.append(next_observation.copy())
+
+ obs = next_observation
+ except KeyboardInterrupt:
+ pass
+
+ print(f"Total reward: {total_reward}")
diff --git a/examples/rl/run_diffuser_locomotion.py b/examples/rl/run_diffuser_locomotion.py
new file mode 100644
index 0000000000..e89181610b
--- /dev/null
+++ b/examples/rl/run_diffuser_locomotion.py
@@ -0,0 +1,57 @@
+import d4rl # noqa
+import gym
+import tqdm
+from diffusers.experimental import ValueGuidedRLPipeline
+
+
+config = dict(
+ n_samples=64,
+ horizon=32,
+ num_inference_steps=20,
+ n_guide_steps=2,
+ scale_grad_by_std=True,
+ scale=0.1,
+ eta=0.0,
+ t_grad_cutoff=2,
+ device="cpu",
+)
+
+
+if __name__ == "__main__":
+ env_name = "hopper-medium-v2"
+ env = gym.make(env_name)
+
+ pipeline = ValueGuidedRLPipeline.from_pretrained(
+ "bglick13/hopper-medium-v2-value-function-hor32",
+ env=env,
+ )
+
+ env.seed(0)
+ obs = env.reset()
+ total_reward = 0
+ total_score = 0
+ T = 1000
+ rollout = [obs.copy()]
+ try:
+ for t in tqdm.tqdm(range(T)):
+ # call the policy
+ denorm_actions = pipeline(obs, planning_horizon=32)
+
+ # execute action in environment
+ next_observation, reward, terminal, _ = env.step(denorm_actions)
+ score = env.get_normalized_score(total_reward)
+ # update return
+ total_reward += reward
+ total_score += score
+ print(
+ f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
+ f" {total_score}"
+ )
+ # save observations for rendering
+ rollout.append(next_observation.copy())
+
+ obs = next_observation
+ except KeyboardInterrupt:
+ pass
+
+ print(f"Total reward: {total_reward}")
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index fc9380edcd..532ce4a741 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -12,13 +12,13 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
-import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from diffusers.utils import PIL_INTERPOLATION
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
@@ -260,10 +260,10 @@ class TextualInversionDataset(Dataset):
self._length = self.num_images * repeats
self.interpolation = {
- "linear": PIL.Image.LINEAR,
- "bilinear": PIL.Image.BILINEAR,
- "bicubic": PIL.Image.BICUBIC,
- "lanczos": PIL.Image.LANCZOS,
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
}[interpolation]
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index be2b7ffb54..008fe812c9 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -14,7 +14,6 @@ from torch.utils.data import Dataset
import jax
import jax.numpy as jnp
import optax
-import PIL
import transformers
from diffusers import (
FlaxAutoencoderKL,
@@ -24,6 +23,7 @@ from diffusers import (
FlaxUNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
+from diffusers.utils import PIL_INTERPOLATION
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
@@ -246,10 +246,10 @@ class TextualInversionDataset(Dataset):
self._length = self.num_images * repeats
self.interpolation = {
- "linear": PIL.Image.LINEAR,
- "bilinear": PIL.Image.BILINEAR,
- "bicubic": PIL.Image.BICUBIC,
- "lanczos": PIL.Image.LANCZOS,
+ "linear": PIL_INTERPOLATION["linear"],
+ "bilinear": PIL_INTERPOLATION["bilinear"],
+ "bicubic": PIL_INTERPOLATION["bicubic"],
+ "lanczos": PIL_INTERPOLATION["lanczos"],
}[interpolation]
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
diff --git a/scripts/convert_models_diffuser_to_diffusers.py b/scripts/convert_models_diffuser_to_diffusers.py
new file mode 100644
index 0000000000..9475f7da93
--- /dev/null
+++ b/scripts/convert_models_diffuser_to_diffusers.py
@@ -0,0 +1,100 @@
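+"""Convert pretrained Diffuser (https://arxiv.org/abs/2205.09991) UNet and value-function checkpoints to the diffusers `UNet1DModel` format."""
+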
+import json
+import os
+
+import torch
+
+from diffusers import UNet1DModel
+
+
+os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
+os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)
+
+os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)
+
+
+def unet(hor):
+ if hor == 128:
+ down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
+ block_out_channels = (32, 128, 256)
+ up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
+
+ elif hor == 32:
+ down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
+ block_out_channels = (32, 64, 128, 256)
+ up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
+ model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
+ state_dict = model.state_dict()
+ config = dict(
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ up_block_types=up_block_types,
+ layers_per_block=1,
+ use_timestep_embedding=True,
+ out_block_type="OutConv1DBlock",
+ norm_num_groups=8,
+ downsample_each_block=False,
+ in_channels=14,
+ out_channels=14,
+ extra_in_channels=0,
+ time_embedding_type="positional",
+ flip_sin_to_cos=False,
+ freq_shift=1,
+ sample_size=65536,
+ mid_block_type="MidResTemporalBlock1D",
+ act_fn="mish",
+ )
+ hf_value_function = UNet1DModel(**config)
+ print(f"length of state dict: {len(state_dict.keys())}")
+ print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
+ mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
+ for k, v in mapping.items():
+ state_dict[v] = state_dict.pop(k)
+ hf_value_function.load_state_dict(state_dict)
+
+ torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
+ with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
+ json.dump(config, f)
+
+
+def value_function():
+ config = dict(
+ in_channels=14,
+ down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
+ up_block_types=(),
+ out_block_type="ValueFunction",
+ mid_block_type="ValueFunctionMidBlock1D",
+ block_out_channels=(32, 64, 128, 256),
+ layers_per_block=1,
+ downsample_each_block=True,
+ sample_size=65536,
+ out_channels=14,
+ extra_in_channels=0,
+ time_embedding_type="positional",
+ use_timestep_embedding=True,
+ flip_sin_to_cos=False,
+ freq_shift=1,
+ norm_num_groups=8,
+ act_fn="mish",
+ )
+
+ model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
+ state_dict = model
+ hf_value_function = UNet1DModel(**config)
+ print(f"length of state dict: {len(state_dict.keys())}")
+ print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
+
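+    # map checkpoint keys to diffusers keys positionally (both dicts share the same parameter order)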
+ mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys()))
+ for k, v in mapping.items():
+ state_dict[v] = state_dict.pop(k)
+
+ hf_value_function.load_state_dict(state_dict)
+
+ torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
+ with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
+ json.dump(config, f)
+
+
+if __name__ == "__main__":
+ unet(32)
+ # unet(128)
+ value_function()
diff --git a/setup.py b/setup.py
index 1bb6af4b10..d0aff10da6 100644
--- a/setup.py
+++ b/setup.py
@@ -78,7 +78,7 @@ from setuptools import find_packages, setup
# 1. all dependencies should be listed here with their version requirements if any
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [
- "Pillow<10.0", # keep the PIL.Image.Resampling deprecation away
+ "Pillow", # keep the PIL.Image.Resampling deprecation away
"accelerate>=0.11.0",
"black==22.8",
"datasets",
@@ -97,6 +97,7 @@ _deps = [
"pytest",
"pytest-timeout",
"pytest-xdist",
+ "sentencepiece>=0.1.91,!=0.1.92",
"scipy",
"regex!=2019.12.17",
"requests",
@@ -183,6 +184,7 @@ extras["test"] = deps_list(
"pytest",
"pytest-timeout",
"pytest-xdist",
+ "sentencepiece",
"scipy",
"torchvision",
"transformers"
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 19558334af..06855e3a69 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -65,6 +65,8 @@ else:
if is_torch_available() and is_transformers_available():
from .pipelines import (
+ AltDiffusionImg2ImgPipeline,
+ AltDiffusionPipeline,
CycleDiffusionPipeline,
LDMTextToImagePipeline,
StableDiffusionImg2ImgPipeline,
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index fc6ac9b5b9..c4819ddc2e 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -29,7 +29,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError
from . import __version__
-from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
+from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
logger = logging.get_logger(__name__)
@@ -37,6 +37,38 @@ logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json")
+class FrozenDict(OrderedDict):
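+    """An `OrderedDict` whose entries are also exposed as attributes; once ``__init__`` has frozen the instance, mutating methods raise an exception."""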
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for key, value in self.items():
+ setattr(self, key, value)
+
+ self.__frozen = True
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __setattr__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setitem__(name, value)
+
+
class ConfigMixin:
r"""
Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
@@ -49,13 +81,12 @@ class ConfigMixin:
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class).
- - **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
- `from_config` can be used from a class different than the one used to save the config (should be overridden
- by parent class).
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
+ class).
"""
config_name = None
ignore_for_config = []
- _compatible_classes = []
+ has_compatibles = False
def register_to_config(self, **kwargs):
if self.config_name is None:
@@ -104,9 +135,98 @@ class ConfigMixin:
logger.info(f"Configuration saved in {output_config_file}")
@classmethod
- def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
r"""
- Instantiate a Python class from a pre-defined JSON-file.
+ Instantiate a Python class from a config dictionary.
+
+ Parameters:
+ config (`Dict[str, Any]`):
+ A config dictionary from which the Python class will be instantiated. Make sure to only load
+ configuration files of compatible classes.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it is loaded) and to initialize the Python
+ class. `**kwargs` will be passed directly to the underlying scheduler/model's `__init__` method
+ and will overwrite arguments of the same name in `config`.
+
+ Examples:
+
+ ```python
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
+
+ >>> # Download scheduler from huggingface.co and cache.
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
+
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
+
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
+ ```
+ """
+ # <===== TO BE REMOVED WITH DEPRECATION
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
+ if "pretrained_model_name_or_path" in kwargs:
+ config = kwargs.pop("pretrained_model_name_or_path")
+
+ if config is None:
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
+ # ======>
+
+ if not isinstance(config, dict):
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
+ if "Scheduler" in cls.__name__:
+ deprecation_message += (
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
+ " be removed in v1.0.0."
+ )
+ elif "Model" in cls.__name__:
+ deprecation_message += (
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
+ " instead. This functionality will be removed in v1.0.0."
+ )
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
+
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
+
+ # Allow dtype to be specified on initialization
+ if "dtype" in unused_kwargs:
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
+
+ # Return model and optionally state and/or unused_kwargs
+ model = cls(**init_dict)
+
+ # make sure to also save config parameters that might be used for compatible classes
+ model.register_to_config(**hidden_dict)
+
+ # add hidden kwargs of compatible classes to unused_kwargs
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
+
+ if return_unused_kwargs:
+ return (model, unused_kwargs)
+ else:
+ return model
+
+ @classmethod
+ def get_config_dict(cls, *args, **kwargs):
+ deprecation_message = (
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
+ " removed in version v1.0.0"
+ )
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
+ return cls.load_config(*args, **kwargs)
+
+ @classmethod
+ def load_config(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ r"""
+ Load a model or scheduler configuration dictionary from a local path or a repository on the Hugging Face Hub.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
@@ -120,10 +240,6 @@ class ConfigMixin:
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
- ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
- Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
- as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
- checkpoint with 3 labels).
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -161,33 +277,7 @@ class ConfigMixin:
use this method in a firewalled environment.
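+
+ Examples:
+
+ A minimal sketch, mirroring the `from_config` example above:
+
+ ```python
+ >>> from diffusers import DDPMScheduler
+
+ >>> # Download the scheduler configuration as a plain dictionary.
+ >>> config_dict = DDPMScheduler.load_config("google/ddpm-cifar10-32")
+ ```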
-
"""
- config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
- init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
-
- # Allow dtype to be specified on initialization
- if "dtype" in unused_kwargs:
- init_dict["dtype"] = unused_kwargs.pop("dtype")
-
- # Return model and optionally state and/or unused_kwargs
- model = cls(**init_dict)
- return_tuple = (model,)
-
- # Flax schedulers have a state, so return it.
- if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
- state = model.create_state()
- return_tuple += (state,)
-
- if return_unused_kwargs:
- return return_tuple + (unused_kwargs,)
- else:
- return return_tuple if len(return_tuple) > 1 else model
-
- @classmethod
- def get_config_dict(
- cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
@@ -283,6 +373,9 @@ class ConfigMixin:
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+ if return_unused_kwargs:
+ return config_dict, kwargs
+
return config_dict
@staticmethod
@@ -291,6 +384,9 @@ class ConfigMixin:
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
+ # 0. Copy origin config dict
+ original_dict = {k: v for k, v in config_dict.items()}
+
# 1. Retrieve expected config attributes from __init__ signature
expected_keys = cls._get_init_keys(cls)
expected_keys.remove("self")
@@ -310,10 +406,11 @@ class ConfigMixin:
# load diffusers library to import compatible and original scheduler
diffusers_library = importlib.import_module(__name__.split(".")[0])
- # remove attributes from compatible classes that orig cannot expect
- compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
- # filter out None potentially undefined dummy classes
- compatible_classes = [c for c in compatible_classes if c is not None]
+ if cls.has_compatibles:
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
+ else:
+ compatible_classes = []
+
expected_keys_comp_cls = set()
for c in compatible_classes:
expected_keys_c = cls._get_init_keys(c)
@@ -364,7 +461,10 @@ class ConfigMixin:
# 6. Define unused keyword arguments
unused_kwargs = {**config_dict, **kwargs}
- return init_dict, unused_kwargs
+ # 7. Define "hidden" config parameters that were saved for compatible classes
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")}
+
+ return init_dict, unused_kwargs, hidden_config_dict
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
@@ -377,6 +477,12 @@ class ConfigMixin:
@property
def config(self) -> Dict[str, Any]:
+ """
+ Returns the config of the class as a frozen dictionary.
+
+ Returns:
+ `Dict[str, Any]`: Config of the class.
+ """
return self._internal_dict
def to_json_string(self) -> str:
@@ -401,38 +507,6 @@ class ConfigMixin:
writer.write(self.to_json_string())
-class FrozenDict(OrderedDict):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- for key, value in self.items():
- setattr(self, key, value)
-
- self.__frozen = True
-
- def __delitem__(self, *args, **kwargs):
- raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
-
- def setdefault(self, *args, **kwargs):
- raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
-
- def pop(self, *args, **kwargs):
- raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
-
- def update(self, *args, **kwargs):
- raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
-
- def __setattr__(self, name, value):
- if hasattr(self, "__frozen") and self.__frozen:
- raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
- super().__setattr__(name, value)
-
- def __setitem__(self, name, value):
- if hasattr(self, "__frozen") and self.__frozen:
- raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
- super().__setitem__(name, value)
-
-
def register_to_config(init):
r"""
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 59e13da0f2..d187b79145 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -2,7 +2,7 @@
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
deps = {
- "Pillow": "Pillow<10.0",
+ "Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0",
"black": "black==22.8",
"datasets": "datasets",
@@ -21,6 +21,7 @@ deps = {
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"scipy": "scipy",
"regex": "regex!=2019.12.17",
"requests": "requests",
diff --git a/src/diffusers/experimental/README.md b/src/diffusers/experimental/README.md
new file mode 100644
index 0000000000..81a9de81c7
--- /dev/null
+++ b/src/diffusers/experimental/README.md
@@ -0,0 +1,5 @@
+# 🧨 Diffusers Experimental
+
+We are adding experimental code to support novel applications and usages of the Diffusers library.
+Currently, the following experiments are supported:
+* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
\ No newline at end of file
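+
+A minimal usage sketch (the checkpoint names below are hypothetical placeholders, and a
+`d4rl`-style environment exposing `get_dataset()` is assumed):
+
+```python
+import gym
+
+import d4rl  # noqa: F401 -- registers the offline-RL environments with gym
+
+from diffusers import DDPMScheduler, UNet1DModel
+from diffusers.experimental import ValueGuidedRLPipeline
+
+env = gym.make("hopper-medium-v2")
+
+pipeline = ValueGuidedRLPipeline(
+    value_function=UNet1DModel.from_pretrained("<value-function-checkpoint>"),
+    unet=UNet1DModel.from_pretrained("<diffusion-unet-checkpoint>"),
+    scheduler=DDPMScheduler(num_train_timesteps=1000),
+    env=env,
+)
+
+obs = env.reset()
+action = pipeline(obs, planning_horizon=32)  # numpy array with one action
+```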
diff --git a/src/diffusers/experimental/__init__.py b/src/diffusers/experimental/__init__.py
new file mode 100644
index 0000000000..ebc8155403
--- /dev/null
+++ b/src/diffusers/experimental/__init__.py
@@ -0,0 +1 @@
+from .rl import ValueGuidedRLPipeline
diff --git a/src/diffusers/experimental/rl/__init__.py b/src/diffusers/experimental/rl/__init__.py
new file mode 100644
index 0000000000..7b338d3173
--- /dev/null
+++ b/src/diffusers/experimental/rl/__init__.py
@@ -0,0 +1 @@
+from .value_guided_sampling import ValueGuidedRLPipeline
diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py
new file mode 100644
index 0000000000..8d5062e3d4
--- /dev/null
+++ b/src/diffusers/experimental/rl/value_guided_sampling.py
@@ -0,0 +1,129 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+
+import tqdm
+
+from ...models.unet_1d import UNet1DModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDPMScheduler
+
+
+class ValueGuidedRLPipeline(DiffusionPipeline):
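+ r"""
+ Pipeline for value-guided sampling of state-action trajectories, a sketch of the approach in
+ [Diffuser](https://arxiv.org/abs/2205.09991): candidate plans are denoised by `unet` while
+ being nudged along the gradient of a learned value function.
+
+ Parameters:
+ value_function ([`UNet1DModel`]): UNet that predicts the value (expected return) of a trajectory.
+ unet ([`UNet1DModel`]): UNet used to denoise the sampled trajectories.
+ scheduler ([`DDPMScheduler`]): scheduler used to step the denoising process.
+ env: an environment following the OpenAI gym API that also exposes `get_dataset()` (d4rl-style).
+ """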
+ def __init__(
+ self,
+ value_function: UNet1DModel,
+ unet: UNet1DModel,
+ scheduler: DDPMScheduler,
+ env,
+ ):
+ super().__init__()
+ self.value_function = value_function
+ self.unet = unet
+ self.scheduler = scheduler
+ self.env = env
+ self.data = env.get_dataset()
+ self.means = dict()
+ for key in self.data.keys():
+ try:
+ self.means[key] = self.data[key].mean()
+ except Exception:
+ # non-numeric dataset entries (e.g. metadata) have no mean; skip them
+ pass
+ self.stds = dict()
+ for key in self.data.keys():
+ try:
+ self.stds[key] = self.data[key].std()
+ except Exception:
+ # non-numeric dataset entries (e.g. metadata) have no std; skip them
+ pass
+ self.state_dim = env.observation_space.shape[0]
+ self.action_dim = env.action_space.shape[0]
+
+ def normalize(self, x_in, key):
+ return (x_in - self.means[key]) / self.stds[key]
+
+ def de_normalize(self, x_in, key):
+ return x_in * self.stds[key] + self.means[key]
+
+ def to_torch(self, x_in):
+ if isinstance(x_in, dict):
+ return {k: self.to_torch(v) for k, v in x_in.items()}
+ elif torch.is_tensor(x_in):
+ return x_in.to(self.unet.device)
+ return torch.tensor(x_in, device=self.unet.device)
+
+ def reset_x0(self, x_in, cond, act_dim):
+ for key, val in cond.items():
+ x_in[:, key, act_dim:] = val.clone()
+ return x_in
+
+ def run_diffusion(self, x, conditions, n_guide_steps, scale):
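+ # Guided denoising loop: at each scheduler timestep, take `n_guide_steps` gradient ascent
+ # steps on the value function to pull the trajectory toward higher predicted return,
+ # re-apply the conditioning on the current state, then run the usual denoising step.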
+ batch_size = x.shape[0]
+ y = None
+ for i in tqdm.tqdm(self.scheduler.timesteps):
+ # create batch of timesteps to pass into model
+ timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
+ for _ in range(n_guide_steps):
+ with torch.enable_grad():
+ x.requires_grad_()
+ y = self.value_function(x.permute(0, 2, 1), timesteps).sample
+ grad = torch.autograd.grad([y.sum()], [x])[0]
+
+ posterior_variance = self.scheduler._get_variance(i)
+ model_std = torch.exp(0.5 * posterior_variance)
+ grad = model_std * grad
+ grad[timesteps < 2] = 0
+ x = x.detach()
+ x = x + scale * grad
+ x = self.reset_x0(x, conditions, self.action_dim)
+ prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
+ x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
+
+ # apply conditions to the trajectory
+ x = self.reset_x0(x, conditions, self.action_dim)
+ x = self.to_torch(x)
+ return x, y
+
+ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
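+ # Plan `batch_size` candidate trajectories over `planning_horizon` steps, guide them with
+ # the value function, and return the de-normalized first action of the best plan.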
+ # normalize the observations and create batch dimension
+ obs = self.normalize(obs, "observations")
+ obs = obs[None].repeat(batch_size, axis=0)
+
+ conditions = {0: self.to_torch(obs)}
+ shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
+
+ # generate initial noise and apply our conditions (to make the trajectories start at current state)
+ x1 = torch.randn(shape, device=self.unet.device)
+ x = self.reset_x0(x1, conditions, self.action_dim)
+ x = self.to_torch(x)
+
+ # run the diffusion process
+ x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
+
+ # sort output trajectories by value
+ sorted_idx = y.argsort(0, descending=True).squeeze()
+ sorted_values = x[sorted_idx]
+ actions = sorted_values[:, :, : self.action_dim]
+ actions = actions.detach().cpu().numpy()
+ denorm_actions = self.de_normalize(actions, key="actions")
+
+ # select the action with the highest value
+ if y is not None:
+ selected_index = 0
+ else:
+ # if we didn't run value guiding, select a random action
+ selected_index = np.random.randint(0, batch_size)
+ denorm_actions = denorm_actions[selected_index, 0]
+ return denorm_actions
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index e8ea37970e..be9203b4d6 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -557,6 +557,9 @@ class CrossAttention(nn.Module):
return hidden_states
def _memory_efficient_attention_xformers(self, query, key, value):
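+ # xformers kernels can fail on non-contiguous inputs; .contiguous() is a no-op for
+ # tensors that are already contiguous.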
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index b09d43fc2e..0221d891f1 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -62,14 +62,21 @@ def get_timestep_embedding(
class TimestepEmbedding(nn.Module):
- def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
+ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
super().__init__()
- self.linear_1 = nn.Linear(channel, time_embed_dim)
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
- self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+ elif act_fn == "mish":
+ self.act = nn.Mish()
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
def forward(self, sample):
sample = self.linear_1(sample)
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 7bb5416adf..52d056ae96 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -5,6 +5,75 @@ import torch.nn as nn
import torch.nn.functional as F
+class Upsample1D(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ Parameters:
+ channels: channels in the inputs and outputs.
+ use_conv: a bool determining if a convolution is applied.
+ use_conv_transpose: a bool determining if a transposed convolution is used instead of interpolation.
+ out_channels: number of output channels; defaults to `channels`.
+ """
+
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ self.conv = None
+ if use_conv_transpose:
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(x)
+
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ if self.use_conv:
+ x = self.conv(x)
+
+ return x
+
+
+class Downsample1D(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ Parameters:
+ channels: channels in the inputs and outputs.
+ use_conv: a bool determining if a convolution is applied.
+ out_channels: number of output channels; defaults to `channels`.
+ padding: padding applied to the strided convolution.
+ """
+
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.conv(x)
+
+
class Upsample2D(nn.Module):
"""
An upsampling layer with an optional convolution.
@@ -12,7 +81,8 @@ class Upsample2D(nn.Module):
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
- dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
+ use_conv_transpose: a bool determining if a transposed convolution is used instead of interpolation.
+ out_channels: number of output channels; defaults to `channels`.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
@@ -80,7 +150,8 @@ class Downsample2D(nn.Module):
Parameters:
channels: channels in the inputs and outputs.
use_conv: a bool determining if a convolution is applied.
- dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
+ out_channels: number of output channels; defaults to `channels`.
+ padding: padding applied to the strided convolution.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
@@ -415,6 +486,69 @@ class Mish(torch.nn.Module):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
+# unet_rl.py
+def rearrange_dims(tensor):
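+ """
+ Moves a singleton axis in or out: `(B, C)` -> `(B, C, 1)`, `(B, C, L)` -> `(B, C, 1, L)` and
+ `(B, C, 1, L)` -> `(B, C, L)`. Used around the group norm in `Conv1dBlock`.
+ """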
+ if len(tensor.shape) == 2:
+ return tensor[:, :, None]
+ if len(tensor.shape) == 3:
+ return tensor[:, :, None, :]
+ elif len(tensor.shape) == 4:
+ return tensor[:, :, 0, :]
+ else:
+ raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
+
+
+class Conv1dBlock(nn.Module):
+ """
+ Conv1d --> GroupNorm --> Mish
+ """
+
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+ super().__init__()
+
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
+ self.mish = nn.Mish()
+
+ def forward(self, x):
+ x = self.conv1d(x)
+ x = rearrange_dims(x)
+ x = self.group_norm(x)
+ x = rearrange_dims(x)
+ x = self.mish(x)
+ return x
+
+
+# unet_rl.py
+class ResidualTemporalBlock1D(nn.Module):
+ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
+ super().__init__()
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
+
+ self.time_emb_act = nn.Mish()
+ self.time_emb = nn.Linear(embed_dim, out_channels)
+
+ self.residual_conv = (
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
+ )
+
+ def forward(self, x, t):
+ """
+ Args:
+ x : [ batch_size x inp_channels x horizon ]
+ t : [ batch_size x embed_dim ]
+
+ returns:
+ out : [ batch_size x out_channels x horizon ]
+ """
+ t = self.time_emb_act(t)
+ t = self.time_emb(t)
+ out = self.conv_in(x) + rearrange_dims(t)
+ out = self.conv_out(out)
+ return out + self.residual_conv(x)
+
+
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py
index cc0685deb9..29d1d707f5 100644
--- a/src/diffusers/models/unet_1d.py
+++ b/src/diffusers/models/unet_1d.py
@@ -1,3 +1,17 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
-from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block
+from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@dataclass
@@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
implements for all the model (such as downloading or saving, etc.)
Parameters:
- sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime.
+ sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
- freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
+ freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to :
obj:`False`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to :
@@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin):
obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(32, 32, 64)`): Tuple of block output channels.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): block type for the middle of the UNet.
+ out_block_type (`str`, *optional*, defaults to `None`): optional output processing block of the UNet.
+ act_fn (`str`, *optional*, defaults to `None`): optional activation function used in UNet blocks.
+ norm_num_groups (`int`, *optional*, defaults to 8): the number of groups for group normalization.
+ layers_per_block (`int`, *optional*, defaults to 1): the number of layers per UNet block.
+ downsample_each_block (`bool`, *optional*, defaults to `False`):
+ experimental feature for using a UNet without upsampling.
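+
+ Examples:
+
+ A minimal instantiation sketch for trajectory data; the channel counts and block choices
+ below are hypothetical:
+
+ ```python
+ >>> from diffusers import UNet1DModel
+
+ >>> model = UNet1DModel(
+ ... in_channels=14,
+ ... out_channels=14,
+ ... block_out_channels=(32, 64),
+ ... down_block_types=("DownResnetBlock1D", "DownResnetBlock1D"),
+ ... up_block_types=("UpResnetBlock1D", "UpResnetBlock1D"),
+ ... mid_block_type="MidResTemporalBlock1D",
+ ... use_timestep_embedding=True,
+ ... act_fn="mish",
+ ... )
+ ```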
"""
@register_to_config
@@ -54,16 +75,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
- freq_shift: int = 0,
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
+ freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
- mid_block_type: str = "UNetMidBlock1D",
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
+ mid_block_type: Tuple[str] = "UNetMidBlock1D",
+ out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64),
+ act_fn: str = None,
+ norm_num_groups: int = 8,
+ layers_per_block: int = 1,
+ downsample_each_block: bool = False,
):
super().__init__()
-
self.sample_size = sample_size
# time
@@ -73,12 +98,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ self.time_proj = Timesteps(
+ block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
+ )
timestep_input_dim = block_out_channels[0]
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
- self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ self.time_mlp = TimestepEmbedding(
+ in_channels=timestep_input_dim,
+ time_embed_dim=time_embed_dim,
+ act_fn=act_fn,
+ out_dim=block_out_channels[0],
+ )
self.down_blocks = nn.ModuleList([])
self.mid_block = None
@@ -94,38 +126,66 @@ class UNet1DModel(ModelMixin, ConfigMixin):
if i == 0:
input_channel += extra_in_channels
+ is_final_block = i == len(block_out_channels) - 1
+
down_block = get_down_block(
down_block_type,
+ num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
+ temb_channels=block_out_channels[0],
+ add_downsample=not is_final_block or downsample_each_block,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = get_mid_block(
- mid_block_type=mid_block_type,
- mid_channels=block_out_channels[-1],
+ mid_block_type,
in_channels=block_out_channels[-1],
- out_channels=None,
+ mid_channels=block_out_channels[-1],
+ out_channels=block_out_channels[-1],
+ embed_dim=block_out_channels[0],
+ num_layers=layers_per_block,
+ add_downsample=downsample_each_block,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
+ if out_block_type is None:
+ final_upsample_channels = out_channels
+ else:
+ final_upsample_channels = block_out_channels[0]
+
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
- output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels
+ output_channel = (
+ reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
+ )
+
+ is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
+ num_layers=layers_per_block,
in_channels=prev_output_channel,
out_channels=output_channel,
+ temb_channels=block_out_channels[0],
+ add_upsample=not is_final_block,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
- # TODO(PVP, Nathan) placeholder for RL application to be merged shortly
- # Totally fine to add another layer with a if statement - no need for nn.Identity here
+ # out
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+ self.out_block = get_out_block(
+ out_block_type=out_block_type,
+ num_groups_out=num_groups_out,
+ embed_dim=block_out_channels[0],
+ out_channels=out_channels,
+ act_fn=act_fn,
+ fc_dim=block_out_channels[-1] // 4,
+ )
def forward(
self,
@@ -144,12 +204,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
- # 1. time
- if len(timestep.shape) == 0:
- timestep = timestep[None]
- timestep_embed = self.time_proj(timestep)[..., None]
- timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ timestep_embed = self.time_proj(timesteps)
+ if self.config.use_timestep_embedding:
+ timestep_embed = self.time_mlp(timestep_embed)
+ else:
+ timestep_embed = timestep_embed[..., None]
+ timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
# 2. down
down_block_res_samples = ()
@@ -158,13 +226,18 @@ class UNet1DModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
# 3. mid
- sample = self.mid_block(sample)
+ if self.mid_block:
+ sample = self.mid_block(sample, timestep_embed)
# 4. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-1:]
down_block_res_samples = down_block_res_samples[:-1]
- sample = upsample_block(sample, res_samples)
+ sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
+
+ # 5. post-process
+ if self.out_block:
+ sample = self.out_block(sample, timestep_embed)
if not return_dict:
return (sample,)
diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py
index 9009071d1e..fc758ebbb0 100644
--- a/src/diffusers/models/unet_1d_blocks.py
+++ b/src/diffusers/models/unet_1d_blocks.py
@@ -17,6 +17,256 @@ import torch
import torch.nn.functional as F
from torch import nn
+from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
+
+
+class DownResnetBlock1D(nn.Module):
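+ # Residual down block for 1D UNets: a stack of `ResidualTemporalBlock1D`s followed by an
+ # optional nonlinearity and an optional strided-conv downsample; also returns the
+ # pre-downsample hidden states for the skip connections.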
+ def __init__(
+ self,
+ in_channels,
+ out_channels=None,
+ num_layers=1,
+ conv_shortcut=False,
+ temb_channels=32,
+ groups=32,
+ groups_out=None,
+ non_linearity=None,
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ add_downsample=True,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.add_downsample = add_downsample
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ # there will always be at least one resnet
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
+
+ for _ in range(num_layers):
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = nn.Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+ else:
+ self.nonlinearity = None
+
+ self.downsample = None
+ if add_downsample:
+ self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for resnet in self.resnets[1:]:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.nonlinearity is not None:
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.downsample is not None:
+ hidden_states = self.downsample(hidden_states)
+
+ return hidden_states, output_states
+
+
+class UpResnetBlock1D(nn.Module):
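+ # Residual up block for 1D UNets: concatenates the skip connection, applies a stack of
+ # `ResidualTemporalBlock1D`s, then an optional nonlinearity and transposed-conv upsample.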
+ def __init__(
+ self,
+ in_channels,
+ out_channels=None,
+ num_layers=1,
+ temb_channels=32,
+ groups=32,
+ groups_out=None,
+ non_linearity=None,
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.time_embedding_norm = time_embedding_norm
+ self.add_upsample = add_upsample
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ # there will always be at least one resnet
+ resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
+
+ for _ in range(num_layers):
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = nn.Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+ else:
+ self.nonlinearity = None
+
+ self.upsample = None
+ if add_upsample:
+ self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
+
+ def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
+ if res_hidden_states_tuple is not None:
+ res_hidden_states = res_hidden_states_tuple[-1]
+ hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for resnet in self.resnets[1:]:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.nonlinearity is not None:
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ hidden_states = self.upsample(hidden_states)
+
+ return hidden_states
+
+
+class ValueFunctionMidBlock1D(nn.Module):
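+ # Mid block of the value-function UNet: two residual temporal blocks, each followed by a
+ # strided-conv downsample, so channels and horizon both shrink by a factor of four.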
+ def __init__(self, in_channels, out_channels, embed_dim):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.embed_dim = embed_dim
+
+ self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
+ self.down1 = Downsample1D(out_channels // 2, use_conv=True)
+ self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
+ self.down2 = Downsample1D(out_channels // 4, use_conv=True)
+
+ def forward(self, x, temb=None):
+ x = self.res1(x, temb)
+ x = self.down1(x)
+ x = self.res2(x, temb)
+ x = self.down2(x)
+ return x
+
+
+class MidResTemporalBlock1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ embed_dim,
+ num_layers: int = 1,
+ add_downsample: bool = False,
+ add_upsample: bool = False,
+ non_linearity=None,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.add_downsample = add_downsample
+
+ # there will always be at least one resnet
+ resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
+
+ for _ in range(num_layers):
+ resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = nn.Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+ else:
+ self.nonlinearity = None
+
+ self.upsample = None
+ if add_upsample:
+ self.upsample = Upsample1D(out_channels, use_conv=True)
+
+ self.downsample = None
+ if add_downsample:
+ self.downsample = Downsample1D(out_channels, use_conv=True)
+
+ if self.upsample and self.downsample:
+ raise ValueError("Block cannot downsample and upsample")
+
+ def forward(self, hidden_states, temb):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for resnet in self.resnets[1:]:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsample:
+ hidden_states = self.upsample(hidden_states)
+ if self.downsample:
+ hidden_states = self.downsample(hidden_states)
+
+ return hidden_states
+
+
+class OutConv1DBlock(nn.Module):
+ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
+ super().__init__()
+ self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
+ self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
+ if act_fn == "silu":
+ self.final_conv1d_act = nn.SiLU()
+ if act_fn == "mish":
+ self.final_conv1d_act = nn.Mish()
+ self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
+
+ def forward(self, hidden_states, temb=None):
+ hidden_states = self.final_conv1d_1(hidden_states)
+ hidden_states = rearrange_dims(hidden_states)
+ hidden_states = self.final_conv1d_gn(hidden_states)
+ hidden_states = rearrange_dims(hidden_states)
+ hidden_states = self.final_conv1d_act(hidden_states)
+ hidden_states = self.final_conv1d_2(hidden_states)
+ return hidden_states
+
+
+class OutValueFunctionBlock(nn.Module):
+ def __init__(self, fc_dim, embed_dim):
+ super().__init__()
+ self.final_block = nn.ModuleList(
+ [
+ nn.Linear(fc_dim + embed_dim, fc_dim // 2),
+ nn.Mish(),
+ nn.Linear(fc_dim // 2, 1),
+ ]
+ )
+
+ def forward(self, hidden_states, temb):
+ hidden_states = hidden_states.view(hidden_states.shape[0], -1)
+ hidden_states = torch.cat((hidden_states, temb), dim=-1)
+ for layer in self.final_block:
+ hidden_states = layer(hidden_states)
+
+ return hidden_states
+
_kernels = {
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
@@ -62,7 +312,7 @@ class Upsample1d(nn.Module):
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
- def forward(self, hidden_states):
+ def forward(self, hidden_states, temb=None):
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -162,32 +412,6 @@ class ResConvBlock(nn.Module):
return output
-def get_down_block(down_block_type, out_channels, in_channels):
- if down_block_type == "DownBlock1D":
- return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
- elif down_block_type == "AttnDownBlock1D":
- return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
- elif down_block_type == "DownBlock1DNoSkip":
- return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
- raise ValueError(f"{down_block_type} does not exist.")
-
-
-def get_up_block(up_block_type, in_channels, out_channels):
- if up_block_type == "UpBlock1D":
- return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
- elif up_block_type == "AttnUpBlock1D":
- return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
- elif up_block_type == "UpBlock1DNoSkip":
- return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
- raise ValueError(f"{up_block_type} does not exist.")
-
-
-def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels):
- if mid_block_type == "UNetMidBlock1D":
- return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
- raise ValueError(f"{mid_block_type} does not exist.")
-
-
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels, in_channels, out_channels=None):
super().__init__()
@@ -217,7 +441,7 @@ class UNetMidBlock1D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states):
+ def forward(self, hidden_states, temb=None):
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
@@ -322,7 +546,7 @@ class AttnUpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
- def forward(self, hidden_states, res_hidden_states_tuple):
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -349,7 +573,7 @@ class UpBlock1D(nn.Module):
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
- def forward(self, hidden_states, res_hidden_states_tuple):
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -374,7 +598,7 @@ class UpBlock1DNoSkip(nn.Module):
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states, res_hidden_states_tuple):
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -382,3 +606,63 @@ class UpBlock1DNoSkip(nn.Module):
hidden_states = resnet(hidden_states)
return hidden_states
+
+
+def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
+ if down_block_type == "DownResnetBlock1D":
+ return DownResnetBlock1D(
+ in_channels=in_channels,
+ num_layers=num_layers,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ )
+ elif down_block_type == "DownBlock1D":
+ return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
+ elif down_block_type == "AttnDownBlock1D":
+ return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
+ elif down_block_type == "DownBlock1DNoSkip":
+ return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
+ if up_block_type == "UpResnetBlock1D":
+ return UpResnetBlock1D(
+ in_channels=in_channels,
+ num_layers=num_layers,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ )
+ elif up_block_type == "UpBlock1D":
+ return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
+ elif up_block_type == "AttnUpBlock1D":
+ return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
+ elif up_block_type == "UpBlock1DNoSkip":
+ return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
+ if mid_block_type == "MidResTemporalBlock1D":
+ return MidResTemporalBlock1D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ embed_dim=embed_dim,
+ add_downsample=add_downsample,
+ )
+ elif mid_block_type == "ValueFunctionMidBlock1D":
+ return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
+ elif mid_block_type == "UNetMidBlock1D":
+ return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
+ raise ValueError(f"{mid_block_type} does not exist.")
+
+
+def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
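+ # Optional output head: a small conv stack for trajectory prediction, a fully-connected
+ # head for the value function, or None when no post-processing is requested.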
+ if out_block_type == "OutConv1DBlock":
+ return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
+ elif out_block_type == "ValueFunction":
+ return OutValueFunctionBlock(fc_dim, embed_dim)
+ return None
diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py
index 641c253c86..0432405760 100644
--- a/src/diffusers/models/unet_2d.py
+++ b/src/diffusers/models/unet_2d.py
@@ -51,7 +51,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to :
- obj:`False`): Whether to flip sin to cos for fourier time embedding.
+ obj:`True`): Whether to flip sin to cos for fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
types.
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 7f7f3ecd44..becae75683 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -60,7 +60,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
- flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py
index 4c34e64f78..54bb028139 100644
--- a/src/diffusers/pipeline_flax_utils.py
+++ b/src/diffusers/pipeline_flax_utils.py
@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
- "FlaxSchedulerMixin": ["save_config", "from_config"],
+ "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
},
"transformers": {
@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
>>> from diffusers import FlaxDPMSolverMultistepScheduler
>>> model_id = "runwayml/stable-diffusion-v1-5"
- >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_config(
+ >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
... model_id,
... subfolder="scheduler",
... )
@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
- config_dict = cls.get_config_dict(
+ config_dict = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
else:
cached_folder = pretrained_model_name_or_path
- config_dict = cls.get_config_dict(cached_folder)
+ config_dict = cls.load_config(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
- init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+ init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index a194f3eb34..4ab1695683 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -57,6 +57,7 @@ if is_transformers_available():
INDEX_FILE = "diffusion_pytorch_model.bin"
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
DUMMY_MODULES_FOLDER = "diffusers.utils"
+TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
logger = logging.get_logger(__name__)
@@ -65,7 +66,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
- "SchedulerMixin": ["save_config", "from_config"],
+ "SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
},
@@ -207,7 +208,7 @@ class DiffusionPipeline(ConfigMixin):
if torch_device is None:
return self
- module_names, _ = self.extract_init_dict(dict(self.config))
+ module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
@@ -228,7 +229,7 @@ class DiffusionPipeline(ConfigMixin):
Returns:
`torch.device`: The torch device on which the pipeline is located.
"""
- module_names, _ = self.extract_init_dict(dict(self.config))
+ module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
@@ -377,11 +378,11 @@ class DiffusionPipeline(ConfigMixin):
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
- >>> # Download pipeline, but overwrite scheduler
+ >>> # Use a different scheduler
>>> from diffusers import LMSDiscreteScheduler
- >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
- >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
+ >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+ >>> pipeline.scheduler = scheduler
```
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
@@ -428,7 +429,7 @@ class DiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
- config_dict = cls.get_config_dict(
+ config_dict = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
@@ -474,7 +475,7 @@ class DiffusionPipeline(ConfigMixin):
else:
cached_folder = pretrained_model_name_or_path
- config_dict = cls.get_config_dict(cached_folder)
+ config_dict = cls.load_config(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
@@ -513,7 +514,7 @@ class DiffusionPipeline(ConfigMixin):
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
- init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
+ init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
@@ -592,7 +593,10 @@ class DiffusionPipeline(ConfigMixin):
if load_method_name is None:
none_module = class_obj.__module__
- if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module:
+ is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
+ TRANSFORMERS_DUMMY_MODULES_FOLDER
+ )
+ if is_dummy_path and "dummy" in none_module:
# call class_obj for nice error message of missing requirements
class_obj()
diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md
index 2941660fa2..6ff40d3549 100644
--- a/src/diffusers/pipelines/README.md
+++ b/src/diffusers/pipelines/README.md
@@ -40,7 +40,7 @@ available a colab notebook to directly try them out.
| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* |
| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
-| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* |
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index abb09605e3..c284855aac 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -15,6 +15,7 @@ else:
from ..utils.dummy_pt_objects import * # noqa F403
if is_torch_available() and is_transformers_available():
+ from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
diff --git a/src/diffusers/pipelines/alt_diffusion/__init__.py b/src/diffusers/pipelines/alt_diffusion/__init__.py
new file mode 100644
index 0000000000..09d0d9b785
--- /dev/null
+++ b/src/diffusers/pipelines/alt_diffusion/__init__.py
@@ -0,0 +1,34 @@
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_torch_available, is_transformers_available
+
+
+@dataclass
+# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt
+class AltDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Alt Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
+ List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height, width,
+ num_channels)`, representing the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`):
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, or `None` if safety checking could not be performed.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: Optional[List[bool]]
+
+
+if is_transformers_available() and is_torch_available():
+ from .modeling_roberta_series import RobertaSeriesModelWithTransformation
+ from .pipeline_alt_diffusion import AltDiffusionPipeline
+ from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline
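
For orientation, here is a minimal usage sketch of the pipelines this new subpackage exposes. The `BAAI/AltDiffusion` checkpoint id is an assumption for illustration, not something introduced by this diff, and the sketch assumes `AltDiffusionPipeline` is re-exported at the top level, as the `pipelines/__init__.py` change above suggests:

```python
# Minimal sketch: text-to-image with the new AltDiffusionPipeline.
# "BAAI/AltDiffusion" is an assumed checkpoint id, not part of this diff.
import torch

from diffusers import AltDiffusionPipeline

pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# AltDiffusion's text encoder is multilingual, so non-English prompts work directly.
prompt = "a portrait of a fox in the style of ukiyo-e"
image = pipe(prompt).images[0]
image.save("alt_diffusion_fox.png")
```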
diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
new file mode 100644
index 0000000000..2e92314162
--- /dev/null
+++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py
@@ -0,0 +1,110 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+
+from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
+from transformers.utils import ModelOutput
+
+
+@dataclass
+class TransformationModelOutput(ModelOutput):
+ """
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
+
+ Args:
+ projection_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, project_dim)`):
+ The text embeddings obtained by applying the transformation layer to the model's last hidden state.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+
+ projection_state: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class RobertaSeriesConfig(XLMRobertaConfig):
+ def __init__(
+ self,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ project_dim=512,
+ pooler_fn="cls",
+ learn_encoder=False,
+ use_attention_mask=True,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+ self.project_dim = project_dim
+ self.pooler_fn = pooler_fn
+ self.learn_encoder = learn_encoder
+ self.use_attention_mask = use_attention_mask
+
+
+class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+ base_model_prefix = "roberta"
+ config_class = RobertaSeriesConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.roberta = XLMRobertaModel(config)
+ self.transformation = nn.Linear(config.hidden_size, config.project_dim)
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ):
+ r""" """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.base_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ projection_state = self.transformation(outputs.last_hidden_state)
+
+ return TransformationModelOutput(
+ projection_state=projection_state,
+ last_hidden_state=outputs.last_hidden_state,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
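
A small, self-contained sketch of how this encoder behaves; the tiny config values below are illustrative assumptions, not real checkpoint sizes:

```python
import torch

from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
    RobertaSeriesConfig,
    RobertaSeriesModelWithTransformation,
)

# Tiny, randomly initialized model just to show the output contract.
config = RobertaSeriesConfig(
    vocab_size=1000,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    project_dim=16,
)
model = RobertaSeriesModelWithTransformation(config).eval()

input_ids = torch.randint(0, 1000, (2, 8))  # (batch, seq_len)
with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=torch.ones(2, 8, dtype=torch.long))

print(out.projection_state.shape)   # torch.Size([2, 8, 16]) -- last hidden state after the linear projection
print(out.last_hidden_state.shape)  # torch.Size([2, 8, 32])
```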
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
new file mode 100644
index 0000000000..01b2051db4
--- /dev/null
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -0,0 +1,533 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import torch
+
+from diffusers.utils import is_accelerate_available
+from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import deprecate, logging
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
+class AltDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Alt Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`RobertaSeriesModelWithTransformation`]):
+ Frozen text-encoder. Unlike Stable Diffusion, Alt Diffusion uses a multilingual, XLM-RoBERTa based
+ encoder with a projection head ([`RobertaSeriesModelWithTransformation`]) in place of CLIP's text model.
+ tokenizer (`XLMRobertaTokenizer`):
+ Tokenizer of class
+ [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
+ [`EulerAncestralDiscreteScheduler`], or [`DPMSolverMultistepScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: RobertaSeriesModelWithTransformation,
+ tokenizer: XLMRobertaTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide by the conditions of the Alt Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_xformers_memory_efficient_attention(self):
+ r"""
+ Enable memory efficient attention as implemented in xformers.
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+ time. Speed up at training time is not guaranteed.
+
+ Warning: when memory efficient attention and sliced attention are both enabled, memory efficient
+ attention takes precedence.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(True)
+
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention as implemented in xformers.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(False)
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ def enable_sequential_cpu_offload(self):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device("cuda")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ device (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
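+ # NOTE: 0.18215 is the scaling factor the Stable Diffusion VAE's latents were multiplied
+ # by during training; dividing by it here maps latents back into the VAE's expected range.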
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // 8, width // 8)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.alt_diffusion.AltDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.alt_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.alt_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 8. Post-processing
+ image = self.decode_latents(latents)
+
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 10. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
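
Putting the pieces above together, a hedged sketch of the main `__call__` knobs (scheduler swapping, negative prompts, reproducibility, and the step callback); the `BAAI/AltDiffusion` checkpoint id is assumed as before:

```python
import torch

from diffusers import AltDiffusionPipeline, DPMSolverMultistepScheduler

pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", torch_dtype=torch.float16).to("cuda")
# Any scheduler class accepted by __init__ can be swapped in from the existing config:
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

def log_step(step, timestep, latents):
    # Invoked every `callback_steps` denoising steps with the current latents.
    print(f"step {step:03d} / timestep {int(timestep)} / latents {tuple(latents.shape)}")

generator = torch.Generator("cuda").manual_seed(0)  # fixed seed -> reproducible output
result = pipe(
    "a watercolor painting of a lighthouse at dawn",
    negative_prompt="blurry, low quality",
    num_inference_steps=25,
    guidance_scale=7.5,
    generator=generator,
    callback=log_step,
    callback_steps=5,
)
result.images[0].save("lighthouse.png")
```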
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
new file mode 100644
index 0000000000..294a43e86e
--- /dev/null
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -0,0 +1,580 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from diffusers.utils import is_accelerate_available
+from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
+def preprocess(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
+class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image to image generation using Alt Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`RobertaSeriesModelWithTransformation`]):
+ Frozen text-encoder. Unlike Stable Diffusion, Alt Diffusion uses a multilingual, XLM-RoBERTa based
+ encoder with a projection head ([`RobertaSeriesModelWithTransformation`]) in place of CLIP's text model.
+ tokenizer (`XLMRobertaTokenizer`):
+ Tokenizer of class
+ [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
+ [`EulerAncestralDiscreteScheduler`], or [`DPMSolverMultistepScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: RobertaSeriesModelWithTransformation,
+ tokenizer: XLMRobertaTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide by the conditions of the Alt Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device("cuda")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
+ def enable_xformers_memory_efficient_attention(self):
+ r"""
+ Enable memory efficient attention as implemented in xformers.
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+ time. Speed up at training time is not guaranteed.
+
+ Warning: when memory efficient attention and sliced attention are both enabled, memory efficient
+ attention takes precedence.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(True)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention as implemented in xformers.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(False)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ prompt to be encoded
+ device (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
+
+ if not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = text_input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ return text_embeddings
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
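+ # e.g. num_inference_steps=50, strength=0.8, offset=1 gives init_timestep=41 and
+ # t_start=10, so only the final ~40 of the 50 scheduler steps are actually run.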
+
+ return timesteps
+
+ def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ init_image = init_image.to(device=device, dtype=dtype)
+ init_latent_dist = self.vae.encode(init_image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many init images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ init_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+ `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.alt_diffusion.AltDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Returns:
+ [`~pipelines.alt_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.alt_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps)
+
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # 4. Preprocess image
+ if isinstance(init_image, PIL.Image.Image):
+ init_image = preprocess(init_image)
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+ )
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 11. Convert to PIL
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
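
And the img2img counterpart, as a hedged sketch; the checkpoint id and image URL are illustrative assumptions, not part of this diff:

```python
from io import BytesIO

import requests
import torch
from PIL import Image

from diffusers import AltDiffusionImg2ImgPipeline

pipe = AltDiffusionImg2ImgPipeline.from_pretrained("BAAI/AltDiffusion", torch_dtype=torch.float16).to("cuda")

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((768, 512))

# strength=0.75 keeps much of init_image's structure; strength=1.0 would ignore it entirely.
image = pipe(
    "A fantasy landscape, trending on artstation",
    init_image=init_image,
    strength=0.75,
    guidance_scale=7.5,
).images[0]
image.save("fantasy_landscape.png")
```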
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index b7194664f4..c937a23003 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline):
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
- " DDPMScheduler.from_config(, predict_epsilon=True)`."
+ " DDPMScheduler.from_pretrained(, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
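
A sketch of the call the corrected message recommends; the repo id and `subfolder` layout are assumptions for illustration:

```python
from diffusers import DDPMScheduler

# Instantiate the scheduler with predict_epsilon at load time, as the updated
# deprecation message suggests. "google/ddpm-cat-256" is an example repo id.
scheduler = DDPMScheduler.from_pretrained(
    "google/ddpm-cat-256", subfolder="scheduler", predict_epsilon=True
)
print(scheduler.config.predict_epsilon)  # True
```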
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
index 044ff359e3..b296a4953f 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -17,12 +17,13 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
+from ...utils import PIL_INTERPOLATION
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
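
`PIL_INTERPOLATION` shields this code from Pillow's relocation of the resampling constants (Pillow >= 9.1 exposes them under `PIL.Image.Resampling`). A quick sketch of what `preprocess` actually produces:

```python
from PIL import Image

from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution import preprocess

# preprocess rounds each side down to a multiple of 32 and rescales
# pixel values from [0, 255] into [-1, 1] as a (1, 3, H, W) float tensor.
img = Image.new("RGB", (130, 70), color=(128, 128, 128))
tensor = preprocess(img)

print(tensor.shape)  # torch.Size([1, 3, 64, 128])  (70 -> 64, 130 -> 128)
print(tensor.min().item(), tensor.max().item())  # both ~0.0039 for a uniform gray-128 image
```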
diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md
index a76e4c6682..bc30be4a7b 100644
--- a/src/diffusers/pipelines/stable_diffusion/README.md
+++ b/src/diffusers/pipelines/stable_diffusion/README.md
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, DDIMScheduler
-scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
-lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
+lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
-scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
+scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image
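
The rename throughout this README reflects the split between loading from the Hub and re-instantiating from an in-memory config; a short sketch:

```python
from diffusers import DDIMScheduler

# from_pretrained: fetch the scheduler config from a repository (or local path)
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

# from_config: build a scheduler from an already-loaded config object
scheduler_copy = DDIMScheduler.from_config(scheduler.config)
```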
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
index 528dd33794..b5f4099292 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -19,13 +19,14 @@ import numpy as np
import torch
import PIL
+from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler
-from ...utils import deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -36,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
@@ -178,6 +179,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
@@ -197,14 +199,33 @@ class CycleDiffusionPipeline(DiffusionPipeline):
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
- # set slice_size = `None` to disable `set_attention_slice`
+ # set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device("cuda")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -224,6 +245,26 @@ class CycleDiffusionPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
+ def enable_xformers_memory_efficient_attention(self):
+ r"""
+ Enable memory efficient attention as implemented in xformers.
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+ time. Speed up at training time is not guaranteed.
+
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+ is used.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(True)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention as implemented in xformers.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(False)
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
@@ -260,7 +301,17 @@ class CycleDiffusionPipeline(DiffusionPipeline):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
- text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
@@ -296,7 +347,17 @@ class CycleDiffusionPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -310,6 +371,106 @@ class CycleDiffusionPipeline(DiffusionPipeline):
return text_embeddings
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
+ def check_inputs(self, prompt, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
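The introspection above only forwards `eta`/`generator` to schedulers whose `step()` accepts them; a quick illustration using scheduler classes already named in this diff:

```python
import inspect

from diffusers import DDIMScheduler, PNDMScheduler

ddim_params = set(inspect.signature(DDIMScheduler.step).parameters)
pndm_params = set(inspect.signature(PNDMScheduler.step).parameters)

print("eta" in ddim_params)  # True: DDIM exposes the DDIM-paper eta term
print("eta" in pndm_params)  # False: eta would be silently dropped for PNDM
```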
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps
+
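A hedged worked example of the `get_timesteps` arithmetic (numbers illustrative):

```python
num_inference_steps, strength, offset = 50, 0.3, 1  # offset = steps_offset from the scheduler config

init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 16
t_start = max(num_inference_steps - init_timestep + offset, 0)                          # 35

# scheduler.timesteps[35:] leaves 15 denoising steps, so a larger strength
# starts from a noisier latent and rewrites more of the init image
print(init_timestep, t_start)  # 16 35
```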
+ def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ init_image = init_image.to(device=device, dtype=dtype)
+ init_latent_dist = self.vae.encode(init_image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many init images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+
+ # add noise to latents using the timestep
+ noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ clean_latents = init_latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents, clean_latents
+
@torch.no_grad()
def __call__(
self,
@@ -384,112 +545,43 @@ class CycleDiffusionPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
- if isinstance(prompt, str):
- batch_size = 1
- elif isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
- if batch_size != 1:
- raise ValueError(
- "At the moment only `batch_size=1` is supported for prompts, but you seem to have passed multiple"
- f" prompts: {prompt}. Please make sure to pass only a single prompt."
- )
-
- if strength < 0 or strength > 1:
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
- ):
- raise ValueError(
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
- f" {type(callback_steps)}."
- )
-
- # set timesteps
- self.scheduler.set_timesteps(num_inference_steps)
-
- if isinstance(init_image, PIL.Image.Image):
- init_image = preprocess(init_image)
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps)
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
-
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
+ # 3. Encode input prompt
text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
source_text_embeddings = self._encode_prompt(
source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
)
- # encode the init image into latents and scale the latents
- latents_dtype = text_embeddings.dtype
- init_image = init_image.to(device=self.device, dtype=latents_dtype)
- init_latent_dist = self.vae.encode(init_image).latent_dist
- init_latents = init_latent_dist.sample(generator=generator)
- init_latents = 0.18215 * init_latents
+ # 4. Preprocess image
+ if isinstance(init_image, PIL.Image.Image):
+ init_image = preprocess(init_image)
- if isinstance(prompt, str):
- prompt = [prompt]
- if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
- # expand init_latents for batch_size
- deprecation_message = (
- f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
- " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
- " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
- " your script to pass as many init images as text prompts to suppress this warning."
- )
- deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
- additional_image_per_prompt = len(prompt) // init_latents.shape[0]
- init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
- elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
- raise ValueError(
- f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
- )
- else:
- init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+ # 5. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
- # get the original timestep using init_timestep
- offset = self.scheduler.config.get("steps_offset", 0)
- init_timestep = int(num_inference_steps * strength) + offset
- init_timestep = min(init_timestep, num_inference_steps)
+ # 6. Prepare latent variables
+ latents, clean_latents = self.prepare_latents(
+ init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+ )
+ source_latents = latents
- timesteps = self.scheduler.timesteps[-init_timestep]
- timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
- # add noise to latents using the timesteps
- noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
- clean_latents = init_latents
- init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
- # and should be between [0, 1]
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_step_kwargs = {}
-
- if not (accepts_eta and (0 < eta <= 1)):
- raise ValueError(
- "Currently, only the DDIM scheduler is supported. Please make sure that `pipeline.scheduler` is of"
- f" type {DDIMScheduler.__class__} and not {self.scheduler.__class__}."
- )
-
- extra_step_kwargs["eta"] = eta
-
- latents = init_latents
- source_latents = init_latents
-
- t_start = max(num_inference_steps - init_timestep + offset, 0)
-
- # Some schedulers like PNDM have timesteps as arrays
- # It's more optimized to move all timesteps to correct device beforehand
- timesteps = self.scheduler.timesteps[t_start:].to(self.device)
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ generator = extra_step_kwargs.pop("generator", None)
+ # 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
@@ -551,22 +643,13 @@ class CycleDiffusionPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
- latents = 1 / 0.18215 * latents
- image = self.vae.decode(latents).sample
+ # 9. Post-processing
+ image = self.decode_latents(latents)
- image = (image / 2 + 0.5).clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).numpy()
-
- if self.safety_checker is not None:
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
- self.device
- )
- image, has_nsfw_concept = self.safety_checker(
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
- )
- else:
- has_nsfw_concept = None
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+ # 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
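
Putting the refactored pieces together, a hedged end-to-end sketch of a CycleDiffusion call; the input image, prompts, and guidance values are illustrative assumptions.

```python
from PIL import Image
from diffusers import CycleDiffusionPipeline, DDIMScheduler

model_id = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler).to("cuda")

init_image = Image.open("black_car.png").convert("RGB").resize((512, 512))

image = pipe(
    prompt="A blue colored car",
    source_prompt="A black colored car",
    init_image=init_image,
    strength=0.85,
    guidance_scale=2,
    source_guidance_scale=1,
    eta=0.1,  # CycleDiffusion leans on DDIM's stochastic eta term
    num_inference_steps=100,
).images[0]
image.save("blue_car.png")
```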
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
index 483b5fd2d3..8b4f78c497 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -25,7 +25,7 @@ from ...configuration_utils import FrozenDict
from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
@@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
return 2.0 * image - 1.0
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index 8e5c201319..6228824b3d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -25,7 +25,7 @@ from ...configuration_utils import FrozenDict
from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
@@ -44,7 +44,7 @@ def prepare_mask_and_masked_image(image, mask, latents_shape):
image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
masked_image = image * (image_mask < 127.5)
- mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
+ mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"])
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
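
A toy illustration of the masking convention above (pure numpy, shapes simplified): dark mask pixels (< 127.5) keep the source image, bright pixels are zeroed out for the model to repaint.

```python
import numpy as np

image = np.ones((4, 4, 3), dtype=np.float32)   # stand-in RGB image
image_mask = np.array([[0, 0, 255, 255]] * 4)  # right half marked for inpainting

masked_image = image * (image_mask < 127.5)[..., None]
print(masked_image[..., 0])
# left half stays 1.0 (kept), right half becomes 0.0 (to be repainted)
```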
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 450fbbfb17..65922451f0 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -248,7 +248,17 @@ class StableDiffusionPipeline(DiffusionPipeline):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
- text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
@@ -284,7 +294,17 @@ class StableDiffusionPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -298,6 +318,73 @@ class StableDiffusionPipeline(DiffusionPipeline):
return text_embeddings
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
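The `0.18215` factor pairs with its inverse on the encode side: VAE latents are scaled to roughly unit variance before the UNet sees them, and `decode_latents` undoes that first. A hedged round-trip sketch (the factor itself comes from the original Stable Diffusion training setup):

```python
import torch

scaling_factor = 0.18215

vae_latents = torch.randn(1, 4, 64, 64) / scaling_factor  # pretend raw VAE output
unet_space = scaling_factor * vae_latents                 # ~unit variance for the UNet
back_to_vae = 1 / scaling_factor * unet_space             # what decode_latents does first

assert torch.allclose(vae_latents, back_to_vae)
```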
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // 8, width // 8)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
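A hedged sketch of the `mps` branch above (assumes a torch build with the MPS backend): latents are sampled with a seeded CPU generator and moved afterwards, because `randn` on `mps` is not reproducible.

```python
import torch

generator = torch.Generator(device="cpu").manual_seed(0)
shape = (1, 4, 64, 64)  # (batch, latent channels, height // 8, width // 8)

latents = torch.randn(shape, generator=generator, device="cpu")
# latents = latents.to("mps")  # move after sampling; the draw stays deterministic
```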
@torch.no_grad()
def __call__(
self,
@@ -371,75 +458,45 @@ class StableDiffusionPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
- if isinstance(prompt, str):
- batch_size = 1
- elif isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
- ):
- raise ValueError(
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
- f" {type(callback_steps)}."
- )
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, callback_steps)
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
-
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
+ # 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
- # Unlike in other pipelines, latents need to be generated in the target device
- # for 1-to-1 results reproducibility with the CompVis implementation.
- # However this currently doesn't work in `mps`.
-
- # get the initial random noise unless the user supplied it
- latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
- latents_dtype = text_embeddings.dtype
- if latents is None:
- if device.type == "mps":
- # randn does not work reproducibly on mps
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device)
- else:
- latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype)
- else:
- if latents.shape != latents_shape:
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
- latents = latents.to(device)
-
- # set timesteps and move to the correct device
+ # 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps_tensor = self.scheduler.timesteps
+ timesteps = self.scheduler.timesteps
- # scale the initial noise by the standard deviation required by the scheduler
- latents = latents * self.scheduler.init_noise_sigma
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
- # and should be between [0, 1]
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_step_kwargs = {}
- if accepts_eta:
- extra_step_kwargs["eta"] = eta
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # check if the scheduler accepts generator
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
- if accepts_generator:
- extra_step_kwargs["generator"] = generator
-
- for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # 7. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -459,22 +516,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
- latents = 1 / 0.18215 * latents
- image = self.vae.decode(latents).sample
+ # 8. Post-processing
+ image = self.decode_latents(latents)
- image = (image / 2 + 0.5).clamp(0, 1)
-
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
- if self.safety_checker is not None:
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
- image, has_nsfw_concept = self.safety_checker(
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
- )
- else:
- has_nsfw_concept = None
+ # 9. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+ # 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
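
With the refactor, user-supplied `latents` are validated in `prepare_latents`; a hedged usage sketch for pinning outputs (prompt and seed illustrative):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

generator = torch.Generator(device="cuda").manual_seed(42)
latents = torch.randn((1, pipe.unet.in_channels, 64, 64), generator=generator, device="cuda")

# the same latents always yield the same image for a given prompt and scheduler
image = pipe("a fantasy landscape, trending on artstation", latents=latents).images[0]
```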
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 98c813eed1..4bfbc5fbcb 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -27,12 +27,13 @@ from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
+ DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
-from ...utils import deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -43,7 +44,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
@@ -78,6 +79,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
self,
vae: AutoencoderKL,
@@ -85,7 +87,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
- DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
@@ -139,6 +146,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
@@ -158,14 +166,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
- # set slice_size = `None` to disable `set_attention_slice`
+ # set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -202,6 +212,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
@@ -214,6 +225,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
@@ -256,7 +268,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
- text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
@@ -292,7 +314,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -306,6 +338,103 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return text_embeddings
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(self, prompt, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps
+
+ def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+ init_image = init_image.to(device=device, dtype=dtype)
+ init_latent_dist = self.vae.encode(init_image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ deprecation_message = (
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
+ " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
+ " your script to pass as many init images as text prompts to suppress this warning."
+ )
+ deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
+
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+
+ return latents
+
@torch.no_grad()
def __call__(
self,
@@ -379,102 +508,40 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
- if isinstance(prompt, str):
- batch_size = 1
- elif isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
- if strength < 0 or strength > 1:
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
- ):
- raise ValueError(
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
- f" {type(callback_steps)}."
- )
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps)
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
-
- # set timesteps
- self.scheduler.set_timesteps(num_inference_steps)
-
- if isinstance(init_image, PIL.Image.Image):
- init_image = preprocess(init_image)
-
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
+ # 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
- # encode the init image into latents and scale the latents
- latents_dtype = text_embeddings.dtype
- init_image = init_image.to(device=device, dtype=latents_dtype)
- init_latent_dist = self.vae.encode(init_image).latent_dist
- init_latents = init_latent_dist.sample(generator=generator)
- init_latents = 0.18215 * init_latents
+ # 4. Preprocess image
+ if isinstance(init_image, PIL.Image.Image):
+ init_image = preprocess(init_image)
- if isinstance(prompt, str):
- prompt = [prompt]
- if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0:
- # expand init_latents for batch_size
- deprecation_message = (
- f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
- " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note"
- " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
- " your script to pass as many init images as text prompts to suppress this warning."
- )
- deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
- additional_image_per_prompt = len(prompt) // init_latents.shape[0]
- init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
- elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0:
- raise ValueError(
- f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts."
- )
- else:
- init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
- # get the original timestep using init_timestep
- offset = self.scheduler.config.get("steps_offset", 0)
- init_timestep = int(num_inference_steps * strength) + offset
- init_timestep = min(init_timestep, num_inference_steps)
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+ )
- timesteps = self.scheduler.timesteps[-init_timestep]
- timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=device)
-
- # add noise to latents using the timesteps
- noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=latents_dtype)
- init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
- # and should be between [0, 1]
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_step_kwargs = {}
- if accepts_eta:
- extra_step_kwargs["eta"] = eta
-
- # check if the scheduler accepts generator
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
- if accepts_generator:
- extra_step_kwargs["generator"] = generator
-
- latents = init_latents
-
- t_start = max(num_inference_steps - init_timestep + offset, 0)
-
- # Some schedulers like PNDM have timesteps as arrays
- # It's more optimized to move all timesteps to correct device beforehand
- timesteps = self.scheduler.timesteps[t_start:].to(device)
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ # 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -495,20 +562,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
- latents = 1 / 0.18215 * latents
- image = self.vae.decode(latents).sample
+ # 9. Post-processing
+ image = self.decode_latents(latents)
- image = (image / 2 + 0.5).clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).numpy()
-
- if self.safety_checker is not None:
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
- image, has_nsfw_concept = self.safety_checker(
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
- )
- else:
- has_nsfw_concept = None
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+ # 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
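
An end-to-end hedged sketch of the refactored img2img call (model id and input file are illustrative); `strength=0.75` keeps roughly the last three quarters of the schedule, per `get_timesteps` above.

```python
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("sketch.png").convert("RGB").resize((768, 512))

image = pipe(
    prompt="A fantasy landscape, trending on artstation",
    init_image=init_image,
    strength=0.75,
    guidance_scale=7.5,
).images[0]
image.save("fantasy_landscape.png")
```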
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 3f08f6edae..fea2b3e5a8 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -139,6 +139,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
@@ -158,6 +159,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
@@ -166,6 +168,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -183,6 +186,26 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
+ def enable_xformers_memory_efficient_attention(self):
+ r"""
+ Enable memory efficient attention as implemented in xformers.
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+ time. Speed up at training time is not guaranteed.
+
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+ is used.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(True)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention as implemented in xformers.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(False)
+
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -202,24 +225,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
- def enable_xformers_memory_efficient_attention(self):
- r"""
- Enable memory efficient attention as implemented in xformers.
-
- When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
- time. Speed up at training time is not guaranteed.
-
- Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
- is used.
- """
- self.unet.set_use_memory_efficient_attention_xformers(True)
-
- def disable_xformers_memory_efficient_attention(self):
- r"""
- Disable memory efficient attention as implemented in xformers.
- """
- self.unet.set_use_memory_efficient_attention_xformers(False)
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
@@ -256,7 +261,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
- text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
@@ -292,7 +307,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -306,6 +331,106 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return text_embeddings
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
+ def check_inputs(self, prompt, height, width, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // 8, width // 8)
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
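Passing a seeded `torch.Generator` is what makes `prepare_latents` reproducible, and drawing the noise on CPU first is the mps workaround mentioned in the comment. A small sketch under those assumptions:

```python
import torch

shape = (1, 4, 64, 64)  # (batch, latent channels, height // 8, width // 8)

def make_latents(seed):
    g = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(shape, generator=g, device="cpu")

assert torch.equal(make_latents(0), make_latents(0))  # same seed, same noise

# sigma-space schedulers (Euler, LMS, ...) need the noise pre-scaled;
# for DDIM/PNDM `init_noise_sigma` is simply 1.0
init_noise_sigma = 1.0
latents = make_latents(0) * init_noise_sigma
```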
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ # encode the mask image into latent space so we can concatenate it to the latents
+ masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+ masked_image_latents = 0.18215 * masked_image_latents
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ mask = mask.repeat(batch_size, 1, 1, 1)
+ masked_image_latents = masked_image_latents.repeat(batch_size, 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
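The mask is downsampled by a factor of 8 so it can be concatenated channel-wise with the latents; for the standard inpainting UNet this yields 4 (latents) + 1 (mask) + 4 (masked-image latents) = 9 input channels, which is exactly what the `in_channels` check below verifies. A shape sketch with stand-in tensors:

```python
import torch

batch, h, w = 2, 512, 512
latents = torch.randn(batch, 4, h // 8, w // 8)
mask = torch.rand(batch, 1, h // 8, w // 8)
masked_image_latents = torch.randn(batch, 4, h // 8, w // 8)

unet_input = torch.cat([latents, mask, masked_image_latents], dim=1)
print(unet_input.shape)  # torch.Size([2, 9, 64, 64]) -> unet.config.in_channels == 9
```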
@torch.no_grad()
def __call__(
self,
@@ -390,83 +515,59 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
(nsfw) content, according to the `safety_checker`.
"""
- if isinstance(prompt, str):
- batch_size = 1
- elif isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
- ):
- raise ValueError(
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
- f" {type(callback_steps)}."
- )
+ # 1. Check inputs
+ self.check_inputs(prompt, height, width, callback_steps)
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
-
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
+ # 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
- # get the initial random noise unless the user supplied it
- # Unlike in other pipelines, latents need to be generated in the target device
- # for 1-to-1 results reproducibility with the CompVis implementation.
- # However this currently doesn't work in `mps`.
+ # 4. Preprocess mask and image
+ if isinstance(image, PIL.Image.Image) and isinstance(mask_image, PIL.Image.Image):
+ mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+
+ # 5. Set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps_tensor = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
- latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
- latents_dtype = text_embeddings.dtype
- if latents is None:
- if device.type == "mps":
- # randn does not exist on mps
- latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(device)
- else:
- latents = torch.randn(latents_shape, generator=generator, device=device, dtype=latents_dtype)
- else:
- if latents.shape != latents_shape:
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
- latents = latents.to(device)
-
- # prepare mask and masked_image
- mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
-
- # resize the mask to latents shape as we concatenate the mask to the latents
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
- # and half precision
- mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
- mask = mask.to(device=device, dtype=text_embeddings.dtype)
-
- masked_image = masked_image.to(device=device, dtype=text_embeddings.dtype)
-
- # encode the mask image into latents space so we can concatenate it to the latents
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
- masked_image_latents = 0.18215 * masked_image_latents
-
- # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
- mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
- masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
-
- mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
- masked_image_latents = (
- torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
)
- # aligning device to prevent device errors when concating it with the latent model input
- masked_image_latents = masked_image_latents.to(device=device, dtype=text_embeddings.dtype)
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+ # 8. Check that sizes of mask, masked image and latents match
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
-
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
@@ -476,27 +577,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
" `pipeline.unet` or your `mask_image` or `image` input."
)
- # set timesteps and move to the correct device
- self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps_tensor = self.scheduler.timesteps
-
- # scale the initial noise by the standard deviation required by the scheduler
- latents = latents * self.scheduler.init_noise_sigma
-
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
- # and should be between [0, 1]
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_step_kwargs = {}
- if accepts_eta:
- extra_step_kwargs["eta"] = eta
-
- # check if the scheduler accepts generator
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
- if accepts_generator:
- extra_step_kwargs["generator"] = generator
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ # 10. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -521,22 +605,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
- latents = 1 / 0.18215 * latents
- image = self.vae.decode(latents).sample
+ # 11. Post-processing
+ image = self.decode_latents(latents)
- image = (image / 2 + 0.5).clamp(0, 1)
-
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
- if self.safety_checker is not None:
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
- image, has_nsfw_concept = self.safety_checker(
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
- )
- else:
- has_nsfw_concept = None
+ # 12. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+ # 13. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
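For orientation, a hedged end-to-end usage sketch of the refactored pipeline; the checkpoint id is the usual inpainting weights and the file names are placeholders:

```python
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
mask_image = Image.open("mask.png").convert("RGB").resize((512, 512))  # white = repaint

image = pipe(
    prompt="a red fox sitting on a bench",
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]
image.save("inpainted.png")
```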
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index 612aa3c126..5c2a3e9523 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -19,14 +19,21 @@ import numpy as np
import torch
import PIL
-from tqdm.auto import tqdm
+from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging
+from ...schedulers import (
+ DDIMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
+ PNDMScheduler,
+)
+from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -37,7 +44,7 @@ logger = logging.get_logger(__name__)
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
@@ -48,7 +55,7 @@ def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+ mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -85,17 +92,26 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ scheduler: Union[
+ DDIMScheduler,
+ PNDMScheduler,
+ LMSDiscreteScheduler,
+ EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+ ],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
@@ -143,6 +159,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
feature_extractor=feature_extractor,
)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
@@ -162,14 +179,53 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
- # set slice_size = `None` to disable `set_attention_slice`
+ # set slice_size = `None` to disable attention slicing
self.enable_attention_slicing(None)
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device("cuda")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
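A hedged usage sketch for the offload hook (the pipeline id is illustrative; requires `accelerate` and a CUDA device):

```python
import torch
from diffusers import StableDiffusionInpaintPipelineLegacy

pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# each submodule is moved to the GPU only for its forward() call,
# trading speed for a much smaller peak memory footprint
pipe.enable_sequential_cpu_offload()
```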
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
+ def enable_xformers_memory_efficient_attention(self):
+ r"""
+ Enable memory efficient attention as implemented in xformers.
+
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+ time. Speed up at training time is not guaranteed.
+
+ Warning: When memory efficient attention and sliced attention are both enabled, memory efficient attention
+ is used.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(True)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
+ def disable_xformers_memory_efficient_attention(self):
+ r"""
+ Disable memory efficient attention as implemented in xformers.
+ """
+ self.unet.set_use_memory_efficient_attention_xformers(False)
+
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -225,7 +281,17 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
- text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ text_embeddings = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
@@ -261,7 +327,17 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -275,6 +351,88 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return text_embeddings
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
+ def check_inputs(self, prompt, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:]
+
+ return timesteps
+
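Worked through with concrete (illustrative) values, the arithmetic shows how `strength` truncates the schedule:

```python
num_inference_steps, strength, offset = 50, 0.8, 1

init_timestep = int(num_inference_steps * strength) + offset    # 41
init_timestep = min(init_timestep, num_inference_steps)         # 41
t_start = max(num_inference_steps - init_timestep + offset, 0)  # 10

# timesteps[10:] leaves 40 of 50 steps, i.e. roughly `strength`
# of the full schedule is actually run on the noised init image
print(init_timestep, t_start, num_inference_steps - t_start)    # 41 10 40
```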
+ def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
+ init_image = init_image.to(device=self.device, dtype=dtype)
+ init_latent_dist = self.vae.encode(init_image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ # Expand init_latents for batch_size and num_images_per_prompt
+ init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+ init_latents_orig = init_latents
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=dtype)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ latents = init_latents
+ return latents, init_latents_orig, noise
+
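`init_latents_orig` and `noise` are handed back because the legacy loop keeps the unmasked region pinned to the input image: at every step the original latents are re-noised to the current timestep and blended in via the mask. A simplified sketch of that blend (the forward-noising helper and the mask convention here are illustrative stand-ins for the scheduler's `add_noise`):

```python
import torch

latents = torch.randn(1, 4, 64, 64)              # current denoised latents
init_latents_orig = torch.randn(1, 4, 64, 64)    # clean latents of the input image
noise = torch.randn(1, 4, 64, 64)                # the fixed noise sampled above
mask = (torch.rand(1, 4, 64, 64) > 0.5).float()  # 1 = keep original content

def forward_noise(x0, eps, alpha_prod_t):
    # DDPM-style forward process: x_t = sqrt(a_t) * x0 + sqrt(1 - a_t) * eps
    return alpha_prod_t**0.5 * x0 + (1 - alpha_prod_t) ** 0.5 * eps

init_latents_proper = forward_noise(init_latents_orig, noise, alpha_prod_t=0.7)
latents = init_latents_proper * mask + latents * (1 - mask)
```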
@torch.no_grad()
def __call__(
self,
@@ -353,98 +511,49 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
- if isinstance(prompt, str):
- batch_size = 1
- elif isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
- if strength < 0 or strength > 1:
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
- ):
- raise ValueError(
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
- f" {type(callback_steps)}."
- )
+ # 1. Check inputs
+ self.check_inputs(prompt, strength, callback_steps)
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
-
- # set timesteps
- self.scheduler.set_timesteps(num_inference_steps)
-
- # preprocess image
- if not isinstance(init_image, torch.FloatTensor):
- init_image = preprocess_image(init_image)
-
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
+ # 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
- # encode the init image into latents and scale the latents
- latents_dtype = text_embeddings.dtype
- init_image = init_image.to(device=self.device, dtype=latents_dtype)
- init_latent_dist = self.vae.encode(init_image).latent_dist
- init_latents = init_latent_dist.sample(generator=generator)
- init_latents = 0.18215 * init_latents
+ # 4. Preprocess image and mask
+ if not isinstance(init_image, torch.FloatTensor):
+ init_image = preprocess_image(init_image)
- # Expand init_latents for batch_size and num_images_per_prompt
- init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
- init_latents_orig = init_latents
-
- # preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
- mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
- mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
- # check sizes
- if not mask.shape == init_latents.shape:
- raise ValueError("The mask and init_image should be the same size!")
+ # 5. Set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.get_timesteps(num_inference_steps, strength, device)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
- # get the original timestep using init_timestep
- offset = self.scheduler.config.get("steps_offset", 0)
- init_timestep = int(num_inference_steps * strength) + offset
- init_timestep = min(init_timestep, num_inference_steps)
+ # 6. Prepare latent variables
+ # encode the init image into latents and scale the latents
+ latents, init_latents_orig, noise = self.prepare_latents(
+ init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
+ )
- timesteps = self.scheduler.timesteps[-init_timestep]
- timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+ # 7. Prepare mask latent
+ mask = mask_image.to(device=self.device, dtype=latents.dtype)
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
- # add noise to latents using the timesteps
- noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
- init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
- # and should be between [0, 1]
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_step_kwargs = {}
- if accepts_eta:
- extra_step_kwargs["eta"] = eta
-
- # check if the scheduler accepts generator
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
- if accepts_generator:
- extra_step_kwargs["generator"] = generator
-
- latents = init_latents
-
- t_start = max(num_inference_steps - init_timestep + offset, 0)
-
- # Some schedulers like PNDM have timesteps as arrays
- # It's more optimized to move all timesteps to correct device beforehand
- timesteps = self.scheduler.timesteps[t_start:].to(self.device)
-
- for i, t in tqdm(enumerate(timesteps)):
+ # 9. Denoising loop
+ for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -468,22 +577,13 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
- latents = 1 / 0.18215 * latents
- image = self.vae.decode(latents).sample
+ # 10. Post-processing
+ image = self.decode_latents(latents)
- image = (image / 2 + 0.5).clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).numpy()
-
- if self.safety_checker is not None:
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
- self.device
- )
- image, has_nsfw_concept = self.safety_checker(
- images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
- )
- else:
- has_nsfw_concept = None
+ # 11. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+ # 12. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
index f9f84737be..4e3f252251 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -116,6 +116,12 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
name: module for name, module in text_unet.named_modules() if isinstance(module, Transformer2DModel)
}
+ def _normalize_embeddings(self, encoder_output):
+ embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
+ embeds_pooled = encoder_output.text_embeds
+ embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
+ return embeds
+
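Promoting `_normalize_embeddings` to a method makes it reusable; note that it rescales every token embedding by a single scalar, the L2 norm of the pooled embedding. A shape sketch with stand-in tensors (dimension 768 is illustrative):

```python
import torch

last_hidden_state = torch.randn(1, 77, 768)  # projected per-token states
text_embeds = torch.randn(1, 768)            # pooled (EOS-token) embedding

norm = torch.norm(text_embeds.unsqueeze(1), dim=-1, keepdim=True)  # (1, 1, 1)
embeds = last_hidden_state / norm  # one scalar rescales the whole sequence
print(norm.shape, embeds.shape)    # torch.Size([1, 1, 1]) torch.Size([1, 77, 768])
```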
def _encode_prompt(self, prompt, do_classifier_free_guidance):
r"""
Encodes the prompt into text encoder hidden states.
@@ -126,24 +132,17 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
"""
-
- def _normalize_embeddings(encoder_output):
- embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) # sum == 19677.4570
- embeds_pooled = encoder_output.text_embeds # sum == 260.2655
- embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
- return embeds
-
batch_size = len(prompt) if isinstance(prompt, list) else 1
if do_classifier_free_guidance:
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))
- uncond_embeddings = _normalize_embeddings(uncond_embeddings)
+ uncond_embeddings = self._normalize_embeddings(uncond_embeddings)
# get prompt text embeddings
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))
- text_embeddings = _normalize_embeddings(text_embeddings)
+ text_embeddings = self._normalize_embeddings(text_embeddings)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 75cef635d0..1326b503ed 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -23,7 +23,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin
@@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
@@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
- _compatible_classes = [
- "PNDMScheduler",
- "DDPMScheduler",
- "LMSDiscreteScheduler",
- "EulerDiscreteScheduler",
- "EulerAncestralDiscreteScheduler",
- "DPMSolverMultistepScheduler",
- ]
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
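The per-class string lists are collapsed into one shared constant. Judging from the lists being removed across these files, it plausibly names the seven interchangeable Stable Diffusion schedulers; a sketch of the idea (the exact contents live in `diffusers.utils`):

```python
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
    "DDIMScheduler",
    "DDPMScheduler",
    "PNDMScheduler",
    "LMSDiscreteScheduler",
    "EulerDiscreteScheduler",
    "EulerAncestralDiscreteScheduler",
    "DPMSolverMultistepScheduler",
]

class SomeScheduler:
    # .copy() so mutating one class's list cannot leak into the others
    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
```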
diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index 590e3aac2e..ceef96a4a9 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -23,7 +23,12 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+ _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
@@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
stable diffusion.
"""
+ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
@property
def has_state(self):
return True
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index a19d91879c..299a06f4eb 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -22,7 +22,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
-from ..utils import BaseOutput, deprecate
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
@@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
@@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
- _compatible_classes = [
- "DDIMScheduler",
- "PNDMScheduler",
- "LMSDiscreteScheduler",
- "EulerDiscreteScheduler",
- "EulerAncestralDiscreteScheduler",
- "DPMSolverMultistepScheduler",
- ]
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
@@ -204,6 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# for rl-diffuser https://arxiv.org/abs/2205.09991
elif variance_type == "fixed_small_log":
variance = torch.log(torch.clamp(variance, min=1e-20))
+ variance = torch.exp(0.5 * variance)
elif variance_type == "fixed_large":
variance = self.betas[t]
elif variance_type == "fixed_large_log":
@@ -248,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
- " DDPMScheduler.from_config(, predict_epsilon=True)`."
+ " DDPMScheduler.from_pretrained(, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
@@ -301,7 +295,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
- variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
+ if self.variance_type == "fixed_small_log":
+ variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
+ else:
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
pred_prev_sample = pred_prev_sample + variance
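With this change, `_get_variance` under `fixed_small_log` returns `exp(0.5 * log σ²) = σ`, already a standard deviation, so the sampling branch above must skip the extra `** 0.5` it applies for the other variance types. A quick numeric check:

```python
import math

sigma_squared = 0.04                # a toy variance
log_var = math.log(sigma_squared)
returned = math.exp(0.5 * log_var)  # what _get_variance now yields

print(returned, math.sqrt(sigma_squared))  # 0.2 0.2 -> no further sqrt needed
```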
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index f1b04a0417..480cbda73c 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -24,7 +24,12 @@ from jax import random
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import deprecate
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+ _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2006.11239
@@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
+ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
@property
def has_state(self):
return True
@@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
- " DDPMScheduler.from_config(, predict_epsilon=True)`."
+ " DDPMScheduler.from_pretrained(, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index d166354809..472b24637d 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -21,6 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
- _compatible_classes = [
- "DDIMScheduler",
- "DDPMScheduler",
- "PNDMScheduler",
- "LMSDiscreteScheduler",
- "EulerDiscreteScheduler",
- "EulerAncestralDiscreteScheduler",
- ]
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
index c9a6d1cd5c..d6fa383534 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
@@ -23,7 +23,12 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+ _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
@@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
+ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
@property
def has_state(self):
return True
diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
index 621b5c17c0..f3abf017d9 100644
--- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -19,7 +19,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
from .scheduling_utils import SchedulerMixin
@@ -52,8 +52,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -67,14 +67,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
- _compatible_classes = [
- "DDIMScheduler",
- "DDPMScheduler",
- "LMSDiscreteScheduler",
- "PNDMScheduler",
- "EulerDiscreteScheduler",
- "DPMSolverMultistepScheduler",
- ]
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 2f9e938474..d9991bc3a0 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -19,7 +19,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
from .scheduling_utils import SchedulerMixin
@@ -53,8 +53,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -68,14 +68,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
- _compatible_classes = [
- "DDIMScheduler",
- "DDPMScheduler",
- "LMSDiscreteScheduler",
- "PNDMScheduler",
- "EulerAncestralDiscreteScheduler",
- "DPMSolverMultistepScheduler",
- ]
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py
index fb413a2805..e5495713a8 100644
--- a/src/diffusers/schedulers/scheduling_ipndm.py
+++ b/src/diffusers/schedulers/scheduling_ipndm.py
@@ -28,8 +28,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py
index 743f2e061c..b2eb332aed 100644
--- a/src/diffusers/schedulers/scheduling_karras_ve.py
+++ b/src/diffusers/schedulers/scheduling_karras_ve.py
@@ -56,8 +56,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py
index 78ab007954..c4e612c3cc 100644
--- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py
+++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -67,8 +67,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 373c373ee0..8a9aedb41b 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -21,7 +21,7 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin
@@ -52,8 +52,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -67,14 +67,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
- _compatible_classes = [
- "DDIMScheduler",
- "DDPMScheduler",
- "PNDMScheduler",
- "EulerDiscreteScheduler",
- "EulerAncestralDiscreteScheduler",
- "DPMSolverMultistepScheduler",
- ]
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
index 20982d38aa..21f25f72fa 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -20,7 +20,12 @@ import jax.numpy as jnp
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+ _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
@flax.struct.dataclass
@@ -49,8 +54,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -63,6 +68,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
"""
+ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
@property
def has_state(self):
return True
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index eec18af8d3..8bf0a59582 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -21,6 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -60,8 +61,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
@@ -88,14 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
- _compatible_classes = [
- "DDIMScheduler",
- "DDPMScheduler",
- "LMSDiscreteScheduler",
- "EulerDiscreteScheduler",
- "EulerAncestralDiscreteScheduler",
- "DPMSolverMultistepScheduler",
- ]
+ _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@register_to_config
def __init__(
diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index 357ecfe046..298e62de20 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -23,7 +23,12 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+from .scheduling_utils_flax import (
+ _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+ FlaxSchedulerMixin,
+ FlaxSchedulerOutput,
+ broadcast_to_shape_from_left,
+)
def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -87,8 +92,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2202.09778
@@ -114,6 +119,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
stable diffusion.
"""
+ _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+
@property
def has_state(self):
return True
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
index 1751f41c66..55625c1bfa 100644
--- a/src/diffusers/schedulers/scheduling_repaint.py
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -77,8 +77,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf
diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py
index d31adbc3c6..1d436ab0cb 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve.py
@@ -50,8 +50,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
index d3eadede61..d1f762bc90 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -64,8 +64,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py
index a37a159a87..537d6f7e2a 100644
--- a/src/diffusers/schedulers/scheduling_sde_vp.py
+++ b/src/diffusers/schedulers/scheduling_sde_vp.py
@@ -29,8 +29,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.
For more information, see the original paper: https://arxiv.org/abs/2011.13456
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index 29bf982f6a..90ab674e38 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
+import os
from dataclasses import dataclass
+from typing import Any, Dict, Optional, Union
import torch
@@ -38,6 +41,114 @@ class SchedulerOutput(BaseOutput):
class SchedulerMixin:
"""
Mixin containing common functions for the schedulers.
+
+ Class attributes:
+        - **_compatibles** (`List[str]`) -- A list of scheduler classes that are compatible with the parent class,
+          so that `from_config` can be used from a class different from the one used to save the config (should be
+          overridden by subclasses).
"""
config_name = SCHEDULER_CONFIG_NAME
+ _compatibles = []
+ has_compatibles = True
+
+ @classmethod
+ def from_pretrained(
+ cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ subfolder: Optional[str] = None,
+ return_unused_kwargs=False,
+ **kwargs,
+ ):
+ r"""
+ Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+                - A path to a *directory* containing the scheduler configurations saved using
+ [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
+ subfolder (`str`, *optional*):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+        <Tip>
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
+        </Tip>
+
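+        Examples:
+
+        A minimal usage sketch (the repo id below is illustrative and requires network access):
+
+        ```py
+        >>> from diffusers import PNDMScheduler
+
+        >>> scheduler = PNDMScheduler.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
+        ... )
+        ```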
+ """
+ config, kwargs = cls.load_config(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ subfolder=subfolder,
+ return_unused_kwargs=True,
+ **kwargs,
+ )
+ return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~SchedulerMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
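+
+        Examples:
+
+        A minimal round-trip sketch (the local path is illustrative):
+
+        ```py
+        >>> from diffusers import DDIMScheduler
+
+        >>> scheduler = DDIMScheduler(num_train_timesteps=1000)
+        >>> scheduler.save_pretrained("./my_scheduler")  # writes scheduler_config.json
+        >>> scheduler = DDIMScheduler.from_pretrained("./my_scheduler")
+        ```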
+ """
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+ @property
+ def compatibles(self):
+ """
+ Returns all schedulers that are compatible with this scheduler
+
+ Returns:
+ `List[SchedulerMixin]`: List of compatible schedulers
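+
+        Examples:
+
+        An illustrative sketch; the exact classes returned depend on which schedulers are importable:
+
+        ```py
+        >>> from diffusers import DDIMScheduler
+
+        >>> scheduler = DDIMScheduler.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
+        ... )
+        >>> [c.__name__ for c in scheduler.compatibles]  # e.g. includes "PNDMScheduler"
+        ```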
+ """
+ return self._get_compatibles()
+
+ @classmethod
+ def _get_compatibles(cls):
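+        # Resolve the stored class-name strings (plus this class itself) against the
+        # top-level `diffusers` namespace; names that cannot be imported are skipped.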
+ compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+ compatible_classes = [
+ getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
+ ]
+ return compatible_classes
diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
index e545cfe247..b3024ca450 100644
--- a/src/diffusers/schedulers/scheduling_utils_flax.py
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib
+import os
from dataclasses import dataclass
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple, Union
import jax.numpy as jnp
-from ..utils import BaseOutput
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS]
@dataclass
@@ -39,9 +42,123 @@ class FlaxSchedulerOutput(BaseOutput):
class FlaxSchedulerMixin:
"""
Mixin containing common functions for the schedulers.
+
+ Class attributes:
+        - **_compatibles** (`List[str]`) -- A list of scheduler classes that are compatible with the parent class,
+          so that `from_config` can be used from a class different from the one used to save the config (should be
+          overridden by subclasses).
"""
config_name = SCHEDULER_CONFIG_NAME
+ _compatibles = []
+ has_compatibles = True
+
+ @classmethod
+ def from_pretrained(
+ cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ subfolder: Optional[str] = None,
+ return_unused_kwargs=False,
+ **kwargs,
+ ):
+ r"""
+        Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+                - A path to a *directory* containing the scheduler configurations saved using
+                  [`~FlaxSchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
+ subfolder (`str`, *optional*):
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
+ huggingface.co or downloaded locally), you can specify the folder name here.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+        <Tip>
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
+        </Tip>
+
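+        Examples:
+
+        A minimal usage sketch; note that the Flax variant returns the scheduler together with its initial state:
+
+        ```py
+        >>> from diffusers import FlaxPNDMScheduler
+
+        >>> scheduler, state = FlaxPNDMScheduler.from_pretrained(
+        ...     "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
+        ... )
+        ```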
+ """
+        config, kwargs = cls.load_config(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder,
+            return_unused_kwargs=True,
+            **kwargs,
+        )
+ scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs)
+
+        state = None
+        if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False):
+            state = scheduler.create_state()
+
+ if return_unused_kwargs:
+ return scheduler, state, unused_kwargs
+
+ return scheduler, state
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~FlaxSchedulerMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+ @property
+ def compatibles(self):
+ """
+ Returns all schedulers that are compatible with this scheduler
+
+ Returns:
+            `List[FlaxSchedulerMixin]`: List of compatible schedulers
+ """
+ return self._get_compatibles()
+
+ @classmethod
+ def _get_compatibles(cls):
+ compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
+ compatible_classes = [
+ getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
+ ]
+ return compatible_classes
def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py
index dbe320d998..91c46e6554 100644
--- a/src/diffusers/schedulers/scheduling_vq_diffusion.py
+++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py
@@ -112,8 +112,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
- [`~ConfigMixin.from_config`] functions.
+    [`SchedulerMixin`] provides general loading and saving functionality via the [`~SchedulerMixin.save_pretrained`] and
+    [`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2111.14822
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index a00e1f4dcd..909d878ed6 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -38,6 +38,7 @@ from .import_utils import (
)
from .logging import get_logger
from .outputs import BaseOutput
+from .pil_utils import PIL_INTERPOLATION
if is_torch_available():
@@ -72,3 +73,13 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+
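+# Scheduler classes whose configs are interchangeable with Stable Diffusion's default
+# scheduler, so any of them can be swapped in via `from_config(pipe.scheduler.config)`.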
+_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
+ "DDIMScheduler",
+ "DDPMScheduler",
+ "PNDMScheduler",
+ "LMSDiscreteScheduler",
+ "EulerDiscreteScheduler",
+ "EulerAncestralDiscreteScheduler",
+ "DPMSolverMultistepScheduler",
+]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 63e8a60f74..92c163ba74 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -4,6 +4,36 @@
from ..utils import DummyObject, requires_backends
+class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AltDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/pil_utils.py b/src/diffusers/utils/pil_utils.py
new file mode 100644
index 0000000000..39d0a15a4e
--- /dev/null
+++ b/src/diffusers/utils/pil_utils.py
@@ -0,0 +1,21 @@
+import PIL.Image
+import PIL.ImageOps
+from packaging import version
+
+
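+# Pillow 9.1.0 moved the resampling filters into the `PIL.Image.Resampling` enum and
+# deprecated the module-level constants, so expose a version-agnostic lookup table,
+# e.g. `image.resize((64, 64), resample=PIL_INTERPOLATION["lanczos"])`.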
+if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ "nearest": PIL.Image.Resampling.NEAREST,
+ }
+else:
+ PIL_INTERPOLATION = {
+ "linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ "nearest": PIL.Image.NEAREST,
+ }
diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py
index c274ce4192..089d935651 100644
--- a/tests/models/test_models_unet_1d.py
+++ b/tests/models/test_models_unet_1d.py
@@ -18,13 +18,120 @@ import unittest
import torch
from diffusers import UNet1DModel
-from diffusers.utils import slow, torch_device
+from diffusers.utils import floats_tensor, slow, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
-class UnetModel1DTests(unittest.TestCase):
+class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
+ model_class = UNet1DModel
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_features = 14
+ seq_len = 16
+
+ noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
+ time_step = torch.tensor([10] * batch_size).to(torch_device)
+
+ return {"sample": noise, "timestep": time_step}
+
+ @property
+ def input_shape(self):
+ return (4, 14, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 14, 16)
+
+ def test_ema_training(self):
+ pass
+
+ def test_training(self):
+ pass
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_determinism(self):
+ super().test_determinism()
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_outputs_equivalence(self):
+ super().test_outputs_equivalence()
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_from_pretrained_save_pretrained(self):
+ super().test_from_pretrained_save_pretrained()
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_model_from_pretrained(self):
+ super().test_model_from_pretrained()
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_output(self):
+ super().test_output()
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "block_out_channels": (32, 64, 128, 256),
+ "in_channels": 14,
+ "out_channels": 14,
+ "time_embedding_type": "positional",
+ "use_timestep_embedding": True,
+ "flip_sin_to_cos": False,
+ "freq_shift": 1.0,
+ "out_block_type": "OutConv1DBlock",
+ "mid_block_type": "MidResTemporalBlock1D",
+ "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
+ "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
+ "act_fn": "mish",
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_from_pretrained_hub(self):
+ model, loading_info = UNet1DModel.from_pretrained(
+ "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
+ )
+ self.assertIsNotNone(model)
+ self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+ model.to(torch_device)
+ image = model(**self.dummy_input)
+
+ assert image is not None, "Make sure output is not None"
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_output_pretrained(self):
+ model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(0)
+
+ num_features = model.in_channels
+ seq_len = 16
+        noise = torch.randn((1, seq_len, num_features)).permute(
+            0, 2, 1
+        )  # permute to (batch, channels, seq_len) to match the layout used by the original codebase
+ time_step = torch.full((num_features,), 0)
+
+ with torch.no_grad():
+ output = model(noise, time_step).sample.permute(0, 2, 1)
+
+ output_slice = output[0, -3:, -3:].flatten()
+ # fmt: off
+ expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
+ # fmt: on
+ self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
+
+ def test_forward_with_norm_groups(self):
+ # Not implemented yet for this UNet
+ pass
+
@slow
def test_unet_1d_maestro(self):
model_id = "harmonai/maestro-150k"
@@ -43,3 +150,127 @@ class UnetModel1DTests(unittest.TestCase):
assert (output_sum - 224.0896).abs() < 4e-2
assert (output_max - 0.0607).abs() < 4e-4
+
+
+class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
+ model_class = UNet1DModel
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_features = 14
+ seq_len = 16
+
+ noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
+ time_step = torch.tensor([10] * batch_size).to(torch_device)
+
+ return {"sample": noise, "timestep": time_step}
+
+ @property
+ def input_shape(self):
+ return (4, 14, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 14, 1)
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_determinism(self):
+ super().test_determinism()
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_outputs_equivalence(self):
+ super().test_outputs_equivalence()
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_from_pretrained_save_pretrained(self):
+ super().test_from_pretrained_save_pretrained()
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_model_from_pretrained(self):
+ super().test_model_from_pretrained()
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_output(self):
+        # UNetRL is a value function, so its output shape differs from the input shape
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.sample
+
+ self.assertIsNotNone(output)
+ expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ def test_ema_training(self):
+ pass
+
+ def test_training(self):
+ pass
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 14,
+ "out_channels": 14,
+ "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
+ "up_block_types": [],
+ "out_block_type": "ValueFunction",
+ "mid_block_type": "ValueFunctionMidBlock1D",
+ "block_out_channels": [32, 64, 128, 256],
+ "layers_per_block": 1,
+ "downsample_each_block": True,
+ "use_timestep_embedding": True,
+ "freq_shift": 1.0,
+ "flip_sin_to_cos": False,
+ "time_embedding_type": "positional",
+ "act_fn": "mish",
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_from_pretrained_hub(self):
+ value_function, vf_loading_info = UNet1DModel.from_pretrained(
+ "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
+ )
+ self.assertIsNotNone(value_function)
+ self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
+
+ value_function.to(torch_device)
+ image = value_function(**self.dummy_input)
+
+ assert image is not None, "Make sure output is not None"
+
+ @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
+ def test_output_pretrained(self):
+ value_function, vf_loading_info = UNet1DModel.from_pretrained(
+ "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
+ )
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(0)
+
+ num_features = value_function.in_channels
+ seq_len = 14
+        noise = torch.randn((1, seq_len, num_features)).permute(
+            0, 2, 1
+        )  # permute to (batch, channels, seq_len) to match the layout used by the original codebase
+ time_step = torch.full((num_features,), 0)
+
+ with torch.no_grad():
+ output = value_function(noise, time_step).sample
+
+ # fmt: off
+ expected_output_slice = torch.tensor([165.25] * seq_len)
+ # fmt: on
+ self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
+
+ def test_forward_with_norm_groups(self):
+ # Not implemented yet for this UNet
+ pass
diff --git a/tests/pipelines/altdiffusion/__init__.py b/tests/pipelines/altdiffusion/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py
new file mode 100644
index 0000000000..b743d100ce
--- /dev/null
+++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py
@@ -0,0 +1,347 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import random
+import unittest
+
+import numpy as np
+import torch
+
+from diffusers import AltDiffusionPipeline, AutoencoderKL, DDIMScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
+ RobertaSeriesConfig,
+ RobertaSeriesModelWithTransformation,
+)
+from diffusers.utils import floats_tensor, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu
+from transformers import XLMRobertaTokenizer
+
+from ...test_pipelines_common import PipelineTesterMixin
+
+
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
+class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ @property
+ def dummy_image(self):
+ batch_size = 1
+ num_channels = 3
+ sizes = (32, 32)
+
+ image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
+ return image
+
+ @property
+ def dummy_cond_unet(self):
+ torch.manual_seed(0)
+ model = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ return model
+
+ @property
+ def dummy_cond_unet_inpaint(self):
+ torch.manual_seed(0)
+ model = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=9,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ return model
+
+ @property
+ def dummy_vae(self):
+ torch.manual_seed(0)
+ model = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ return model
+
+ @property
+ def dummy_text_encoder(self):
+ torch.manual_seed(0)
+ config = RobertaSeriesConfig(
+ hidden_size=32,
+ project_dim=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ vocab_size=5002,
+ )
+ return RobertaSeriesModelWithTransformation(config)
+
+ @property
+ def dummy_extractor(self):
+ def extract(*args, **kwargs):
+ class Out:
+ def __init__(self):
+ self.pixel_values = torch.ones([0])
+
+                def to(self, device):
+                    self.pixel_values = self.pixel_values.to(device)
+                    return self
+
+ return Out()
+
+ return extract
+
+ def test_alt_diffusion_ddim(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
+ tokenizer.model_max_length = 77
+
+        # assemble the pipeline from the dummy components defined above
+ alt_pipe = AltDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=None,
+ feature_extractor=self.dummy_extractor,
+ )
+ alt_pipe = alt_pipe.to(device)
+ alt_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A photo of an astronaut"
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+ image = output.images
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ image_from_tuple = alt_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ return_dict=False,
+ )[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 128, 128, 3)
+ expected_slice = np.array(
+ [0.49249017, 0.46064827, 0.4790093, 0.50883967, 0.4811985, 0.51540506, 0.5084924, 0.4860553, 0.47318557]
+ )
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_alt_diffusion_pndm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
+ tokenizer.model_max_length = 77
+
+ # make sure here that pndm scheduler skips prk
+ alt_pipe = AltDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=None,
+ feature_extractor=self.dummy_extractor,
+ )
+ alt_pipe = alt_pipe.to(device)
+ alt_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+ image = output.images
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ image_from_tuple = alt_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ return_dict=False,
+ )[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 128, 128, 3)
+ expected_slice = np.array(
+ [0.4786532, 0.45791715, 0.47507674, 0.50763345, 0.48375353, 0.515062, 0.51244247, 0.48673993, 0.47105807]
+ )
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
+ def test_alt_diffusion_fp16(self):
+ """Test that stable diffusion works with fp16"""
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
+ tokenizer.model_max_length = 77
+
+ # put models in fp16
+ unet = unet.half()
+ vae = vae.half()
+ bert = bert.half()
+
+ # make sure here that pndm scheduler skips prk
+ alt_pipe = AltDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=None,
+ feature_extractor=self.dummy_extractor,
+ )
+ alt_pipe = alt_pipe.to(torch_device)
+ alt_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ image = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
+
+ assert image.shape == (1, 128, 128, 3)
+
+
+@slow
+@require_torch_gpu
+class AltDiffusionPipelineIntegrationTests(unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_alt_diffusion(self):
+        # load the full pipeline with its default components
+ alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", safety_checker=None)
+ alt_pipe = alt_pipe.to(torch_device)
+ alt_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ with torch.autocast("cuda"):
+ output = alt_pipe(
+ [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
+ )
+
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array(
+ [0.8720703, 0.87109375, 0.87402344, 0.87109375, 0.8779297, 0.8925781, 0.8823242, 0.8808594, 0.8613281]
+ )
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_alt_diffusion_fast_ddim(self):
+ scheduler = DDIMScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler")
+
+ alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=scheduler, safety_checker=None)
+ alt_pipe = alt_pipe.to(torch_device)
+ alt_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+
+ with torch.autocast("cuda"):
+ output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 512, 512, 3)
+ expected_slice = np.array(
+ [0.9267578, 0.9301758, 0.9013672, 0.9345703, 0.92578125, 0.94433594, 0.9423828, 0.9423828, 0.9160156]
+ )
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_alt_diffusion_text2img_pipeline_fp16(self):
+ torch.cuda.reset_peak_memory_stats()
+ model_id = "BAAI/AltDiffusion"
+ pipe = AltDiffusionPipeline.from_pretrained(
+ model_id, revision="fp16", torch_dtype=torch.float16, safety_checker=None
+ )
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ prompt = "a photograph of an astronaut riding a horse"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ output_chunked = pipe(
+ [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+ )
+ image_chunked = output_chunked.images
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ with torch.autocast(torch_device):
+ output = pipe(
+ [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
+ )
+ image = output.images
+
+        # Make sure results are close enough
+        diff = np.abs(image_chunked.flatten() - image.flatten())
+        # The outputs differ slightly since ops are not always run at the same precision;
+        # however, they should be extremely close.
+        assert diff.mean() < 2e-2
diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
new file mode 100644
index 0000000000..0dab14b317
--- /dev/null
+++ b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
@@ -0,0 +1,256 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import random
+import unittest
+
+import numpy as np
+import torch
+
+from diffusers import AltDiffusionImg2ImgPipeline, AutoencoderKL, PNDMScheduler, UNet2DConditionModel
+from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
+ RobertaSeriesConfig,
+ RobertaSeriesModelWithTransformation,
+)
+from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu
+from transformers import XLMRobertaTokenizer
+
+from ...test_pipelines_common import PipelineTesterMixin
+
+
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
+class AltDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ @property
+ def dummy_image(self):
+ batch_size = 1
+ num_channels = 3
+ sizes = (32, 32)
+
+ image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
+ return image
+
+ @property
+ def dummy_cond_unet(self):
+ torch.manual_seed(0)
+ model = UNet2DConditionModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=32,
+ )
+ return model
+
+ @property
+ def dummy_vae(self):
+ torch.manual_seed(0)
+ model = AutoencoderKL(
+ block_out_channels=[32, 64],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ )
+ return model
+
+ @property
+ def dummy_text_encoder(self):
+ torch.manual_seed(0)
+ config = RobertaSeriesConfig(
+ hidden_size=32,
+ project_dim=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=5006,
+ )
+ return RobertaSeriesModelWithTransformation(config)
+
+ @property
+ def dummy_extractor(self):
+ def extract(*args, **kwargs):
+ class Out:
+ def __init__(self):
+ self.pixel_values = torch.ones([0])
+
+                def to(self, device):
+                    self.pixel_values = self.pixel_values.to(device)
+                    return self
+
+ return Out()
+
+ return extract
+
+ def test_stable_diffusion_img2img_default_case(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
+ tokenizer.model_max_length = 77
+
+ init_image = self.dummy_image.to(device)
+
+ # make sure here that pndm scheduler skips prk
+ alt_pipe = AltDiffusionImg2ImgPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=None,
+ feature_extractor=self.dummy_extractor,
+ )
+ alt_pipe = alt_pipe.to(device)
+ alt_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=device).manual_seed(0)
+ output = alt_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ )
+
+ image = output.images
+
+ generator = torch.Generator(device=device).manual_seed(0)
+ image_from_tuple = alt_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ return_dict=False,
+ )[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array(
+ [0.41293705, 0.38656747, 0.40876025, 0.4782187, 0.4656803, 0.41394007, 0.4142093, 0.47150758, 0.4570448]
+ )
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1.5e-3
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1.5e-3
+
+ @unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
+ def test_stable_diffusion_img2img_fp16(self):
+ """Test that stable diffusion img2img works with fp16"""
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta")
+ tokenizer.model_max_length = 77
+
+ init_image = self.dummy_image.to(torch_device)
+
+ # put models in fp16
+ unet = unet.half()
+ vae = vae.half()
+ bert = bert.half()
+
+ # make sure here that pndm scheduler skips prk
+ alt_pipe = AltDiffusionImg2ImgPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=None,
+ feature_extractor=self.dummy_extractor,
+ )
+ alt_pipe = alt_pipe.to(torch_device)
+ alt_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ image = alt_pipe(
+ [prompt],
+ generator=generator,
+ num_inference_steps=2,
+ output_type="np",
+ init_image=init_image,
+ ).images
+
+ assert image.shape == (1, 32, 32, 3)
+
+
+@slow
+@require_torch_gpu
+class AltDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_stable_diffusion_img2img_pipeline_default(self):
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/sketch-mountains-input.jpg"
+ )
+ init_image = init_image.resize((768, 512))
+ expected_image = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_alt.npy"
+ )
+
+ model_id = "BAAI/AltDiffusion"
+ pipe = AltDiffusionImg2ImgPipeline.from_pretrained(
+ model_id,
+ safety_checker=None,
+ )
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "A fantasy landscape, trending on artstation"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ output = pipe(
+ prompt=prompt,
+ init_image=init_image,
+ strength=0.75,
+ guidance_scale=7.5,
+ generator=generator,
+ output_type="np",
+ )
+ image = output.images[0]
+
+        assert image.shape == (512, 768, 3)
+        # img2img is flaky across GPUs even in fp32, so compare the maximum absolute difference
+        assert np.abs(expected_image - image).max() < 1e-3
diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py
index 72e67e4479..a63ef84c63 100644
--- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py
+++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py
@@ -44,6 +44,10 @@ class PipelineFastTests(unittest.TestCase):
sample_rate=16_000,
in_channels=2,
out_channels=2,
+ flip_sin_to_cos=True,
+ use_timestep_embedding=False,
+ time_embedding_type="fourier",
+ mid_block_type="UNetMidBlock1D",
down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"],
up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"],
)
diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py
index 81c49912be..2d03383599 100644
--- a/tests/pipelines/ddim/test_ddim.py
+++ b/tests/pipelines/ddim/test_ddim.py
@@ -75,7 +75,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
model_id = "google/ddpm-ema-bedroom-256"
unet = UNet2DModel.from_pretrained(model_id)
- scheduler = DDIMScheduler.from_config(model_id)
+ scheduler = DDIMScheduler.from_pretrained(model_id)
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index 14bc094697..ef293109bf 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -106,7 +106,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
- scheduler = DDPMScheduler.from_config(model_id)
+ scheduler = DDPMScheduler.from_pretrained(model_id)
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
ddpm.to(torch_device)
diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
index f402d2f2a7..c04210dede 100644
--- a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
+++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
@@ -19,9 +19,8 @@ import unittest
import numpy as np
import torch
-import PIL
from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel
-from diffusers.utils import floats_tensor, load_image, slow, torch_device
+from diffusers.utils import PIL_INTERPOLATION, floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch
from ...test_pipelines_common import PipelineTesterMixin
@@ -97,7 +96,7 @@ class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/vq_diffusion/teddy_bear_pool.png"
)
- init_image = init_image.resize((64, 64), resample=PIL.Image.LANCZOS)
+ init_image = init_image.resize((64, 64), resample=PIL_INTERPOLATION["lanczos"])
ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
ldm.to(torch_device)
diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py
index 23544dfd24..3ab0efc875 100644
--- a/tests/pipelines/repaint/test_repaint.py
+++ b/tests/pipelines/repaint/test_repaint.py
@@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
model_id = "google/ddpm-ema-celebahq-256"
unet = UNet2DModel.from_pretrained(model_id)
- scheduler = RePaintScheduler.from_config(model_id)
+ scheduler = RePaintScheduler.from_pretrained(model_id)
repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device)
diff --git a/tests/pipelines/score_sde_ve/test_score_sde_ve.py b/tests/pipelines/score_sde_ve/test_score_sde_ve.py
index 55dcc1cea1..9cdf3f0191 100644
--- a/tests/pipelines/score_sde_ve/test_score_sde_ve.py
+++ b/tests/pipelines/score_sde_ve/test_score_sde_ve.py
@@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase):
model_id = "google/ncsnpp-church-256"
model = UNet2DModel.from_pretrained(model_id)
- scheduler = ScoreSdeVeScheduler.from_config(model_id)
+ scheduler = ScoreSdeVeScheduler.from_pretrained(model_id)
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
sde_ve.to(torch_device)
diff --git a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py
index de918c7e5c..7a32b74096 100644
--- a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py
@@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((512, 512))
model_id = "CompVis/stable-diffusion-v1-4"
- scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
+ scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(
model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16"
)
@@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((512, 512))
model_id = "CompVis/stable-diffusion-v1-4"
- scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler")
+ scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)
pipe.to(torch_device)
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
index a1946e39f9..a2b48d27e6 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
@@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_inference_ddim(self):
- ddim_scheduler = DDIMScheduler.from_config(
+ ddim_scheduler = DDIMScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
)
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
@@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_inference_k_lms(self):
- lms_scheduler = LMSDiscreteScheduler.from_config(
+ lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
)
sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
index 61831c64c0..91e4412425 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
@@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
"/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
- lms_scheduler = LMSDiscreteScheduler.from_config(
+ lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
)
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
index 4ba8e273b4..507375bddb 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
@@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
- lms_scheduler = LMSDiscreteScheduler.from_config(
+ lms_scheduler = LMSDiscreteScheduler.from_pretrained(
"runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx"
)
pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 87d238c869..17a293e605 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_fast_ddim(self):
- scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
+ scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler)
sd_pipe = sd_pipe.to(torch_device)
@@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device)
pipe.set_progress_bar_config(disable=None)
- scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+ scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe.scheduler = scheduler
prompt = "a photograph of an astronaut riding a horse"
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index 6d5c6feab5..d86b259eae 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -22,6 +22,7 @@ import torch
from diffusers import (
AutoencoderKL,
+ DDIMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,
@@ -479,7 +480,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)
init_image = init_image.resize((768, 512))
expected_image = load_numpy(
- "https://huggingface.co/datasets/lewington/expected-images/resolve/main/fantasy_landscape.npy"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.npy"
)
model_id = "CompVis/stable-diffusion-v1-4"
@@ -506,7 +507,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
assert image.shape == (512, 768, 3)
# img2img is flaky across GPUs even in fp32, so using MAE here
- assert np.abs(expected_image - image).mean() < 1e-3
+ assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_img2img_pipeline_k_lms(self):
init_image = load_image(
@@ -515,11 +516,11 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
)
init_image = init_image.resize((768, 512))
expected_image = load_numpy(
- "https://huggingface.co/datasets/lewington/expected-images/resolve/main/fantasy_landscape_k_lms.npy"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_k_lms.npy"
)
model_id = "CompVis/stable-diffusion-v1-4"
- lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+ lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
scheduler=lms,
@@ -543,8 +544,44 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
image = output.images[0]
assert image.shape == (512, 768, 3)
- # img2img is flaky across GPUs even in fp32, so using MAE here
- assert np.abs(expected_image - image).mean() < 1e-3
+ assert np.abs(expected_image - image).max() < 1e-3
+
+ def test_stable_diffusion_img2img_pipeline_ddim(self):
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/sketch-mountains-input.jpg"
+ )
+ init_image = init_image.resize((768, 512))
+ expected_image = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_ddim.npy"
+ )
+
+ model_id = "CompVis/stable-diffusion-v1-4"
+ ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+ model_id,
+ scheduler=ddim,
+ safety_checker=None,
+ )
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
+
+ prompt = "A fantasy landscape, trending on artstation"
+
+ generator = torch.Generator(device=torch_device).manual_seed(0)
+ output = pipe(
+ prompt=prompt,
+ init_image=init_image,
+ strength=0.75,
+ guidance_scale=7.5,
+ generator=generator,
+ output_type="np",
+ )
+ image = output.images[0]
+
+ assert image.shape == (512, 768, 3)
+ assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_img2img_intermediate_state(self):
number_of_steps = 0
@@ -612,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
init_image = init_image.resize((768, 512))
model_id = "CompVis/stable-diffusion-v1-4"
- lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+ lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16
)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index 5fcdd71dd6..ce231a1a46 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -215,6 +215,47 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
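+ # passing num_images_per_prompt=2 for a single prompt should return two images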
+ device = "cpu"
+ unet = self.dummy_cond_unet_inpaint
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
+ mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+ # make sure that the PNDM scheduler skips the PRK steps
+ sd_pipe = StableDiffusionInpaintPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=None,
+ feature_extractor=None,
+ )
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "A painting of a squirrel eating a burger"
+ generator = torch.Generator(device=device).manual_seed(0)
+ images = sd_pipe(
+ [prompt],
+ generator=generator,
+ guidance_scale=6.0,
+ num_inference_steps=2,
+ output_type="np",
+ image=init_image,
+ mask_image=mask_image,
+ num_images_per_prompt=2,
+ ).images
+
+ # check that the output is a list of 2 images
+ assert len(images) == 2
+
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
def test_stable_diffusion_inpaint_fp16(self):
"""Test that stable diffusion inpaint_legacy works with fp16"""
@@ -359,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)
model_id = "runwayml/stable-diffusion-inpainting"
- pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
+ pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -396,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
)
model_id = "runwayml/stable-diffusion-inpainting"
- pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler")
+ pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
safety_checker=None,
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
index c5b2572fb7..94106b6ba8 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
@@ -387,7 +387,6 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
assert np.abs(expected_image - image).max() < 1e-3
def test_stable_diffusion_inpaint_legacy_pipeline_k_lms(self):
- # TODO(Anton, Patrick) - I think we can remove this test soon
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
@@ -402,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
)
model_id = "CompVis/stable-diffusion-v1-4"
- lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler")
+ lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
scheduler=lms,
diff --git a/tests/test_config.py b/tests/test_config.py
index 8ae8e1d9e1..0875930e37 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
-import os
import tempfile
import unittest
-import diffusers
from diffusers import (
DDIMScheduler,
DDPMScheduler,
@@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin):
class ConfigTester(unittest.TestCase):
def test_load_not_from_mixin(self):
with self.assertRaises(ValueError):
- ConfigMixin.from_config("dummy_path")
+ ConfigMixin.load_config("dummy_path")
def test_register_to_config(self):
obj = SampleObject()
@@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
obj.save_config(tmpdirname)
- new_obj = SampleObject.from_config(tmpdirname)
+ new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname))
new_config = new_obj.config
# unfreeze configs
@@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase):
assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json
assert config == new_config
- def test_save_load_from_different_config(self):
- obj = SampleObject()
-
- # mock add obj class to `diffusers`
- setattr(diffusers, "SampleObject", SampleObject)
- logger = logging.get_logger("diffusers.configuration_utils")
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- obj.save_config(tmpdirname)
- with CaptureLogger(logger) as cap_logger_1:
- new_obj_1 = SampleObject2.from_config(tmpdirname)
-
- # now save a config parameter that is not expected
- with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
- data = json.load(f)
- data["unexpected"] = True
-
- with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
- json.dump(data, f)
-
- with CaptureLogger(logger) as cap_logger_2:
- new_obj_2 = SampleObject.from_config(tmpdirname)
-
- with CaptureLogger(logger) as cap_logger_3:
- new_obj_3 = SampleObject2.from_config(tmpdirname)
-
- assert new_obj_1.__class__ == SampleObject2
- assert new_obj_2.__class__ == SampleObject
- assert new_obj_3.__class__ == SampleObject2
-
- assert cap_logger_1.out == ""
- assert (
- cap_logger_2.out
- == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
- " be ignored. Please verify your config.json configuration file.\n"
- )
- assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out
-
- def test_save_load_compatible_schedulers(self):
- SampleObject2._compatible_classes = ["SampleObject"]
- SampleObject._compatible_classes = ["SampleObject2"]
-
- obj = SampleObject()
-
- # mock add obj class to `diffusers`
- setattr(diffusers, "SampleObject", SampleObject)
- setattr(diffusers, "SampleObject2", SampleObject2)
- logger = logging.get_logger("diffusers.configuration_utils")
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- obj.save_config(tmpdirname)
-
- # now save a config parameter that is expected by another class, but not origin class
- with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f:
- data = json.load(f)
- data["f"] = [0, 0]
- data["unexpected"] = True
-
- with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f:
- json.dump(data, f)
-
- with CaptureLogger(logger) as cap_logger:
- new_obj = SampleObject.from_config(tmpdirname)
-
- assert new_obj.__class__ == SampleObject
-
- assert (
- cap_logger.out
- == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will"
- " be ignored. Please verify your config.json configuration file.\n"
- )
-
- def test_save_load_from_different_config_comp_schedulers(self):
- SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"]
- SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"]
- SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"]
-
- obj = SampleObject()
-
- # mock add obj class to `diffusers`
- setattr(diffusers, "SampleObject", SampleObject)
- setattr(diffusers, "SampleObject2", SampleObject2)
- setattr(diffusers, "SampleObject3", SampleObject3)
- logger = logging.get_logger("diffusers.configuration_utils")
- logger.setLevel(diffusers.logging.INFO)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- obj.save_config(tmpdirname)
-
- with CaptureLogger(logger) as cap_logger_1:
- new_obj_1 = SampleObject.from_config(tmpdirname)
-
- with CaptureLogger(logger) as cap_logger_2:
- new_obj_2 = SampleObject2.from_config(tmpdirname)
-
- with CaptureLogger(logger) as cap_logger_3:
- new_obj_3 = SampleObject3.from_config(tmpdirname)
-
- assert new_obj_1.__class__ == SampleObject
- assert new_obj_2.__class__ == SampleObject2
- assert new_obj_3.__class__ == SampleObject3
-
- assert cap_logger_1.out == ""
- assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
- assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
-
def test_load_ddim_from_pndm(self):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
- ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+ ddim = DDIMScheduler.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
+ )
assert ddim.__class__ == DDIMScheduler
# no warning should be thrown
@@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
- euler = EulerDiscreteScheduler.from_config(
+ euler = EulerDiscreteScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
@@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
- euler = EulerAncestralDiscreteScheduler.from_config(
+ euler = EulerAncestralDiscreteScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
@@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
- pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+ pndm = PNDMScheduler.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
+ )
assert pndm.__class__ == PNDMScheduler
# no warning should be thrown
@@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
- ddpm = DDPMScheduler.from_config(
+ ddpm = DDPMScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch",
subfolder="scheduler",
predict_epsilon=False,
@@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase):
)
with CaptureLogger(logger) as cap_logger_2:
- ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88)
+ ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)
assert ddpm.__class__ == DDPMScheduler
assert ddpm.config.predict_epsilon is False
@@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase):
logger = logging.get_logger("diffusers.configuration_utils")
with CaptureLogger(logger) as cap_logger:
- dpm = DPMSolverMultistepScheduler.from_config(
+ dpm = DPMSolverMultistepScheduler.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index eabe6ada9f..49bb4f6deb 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -130,7 +130,7 @@ class ModelTesterMixin:
expected_arg_names = ["sample", "timestep"]
self.assertListEqual(arg_names[:2], expected_arg_names)
- def test_model_from_config(self):
+ def test_model_from_pretrained(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
@@ -140,8 +140,8 @@ class ModelTesterMixin:
- # test if the model can be loaded from the config
- # and has all the expected shape
+ # test if the model can be loaded from saved pretrained weights
+ # and has all the expected shape
with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_config(tmpdirname)
- new_model = self.model_class.from_config(tmpdirname)
+ model.save_pretrained(tmpdirname)
+ new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
new_model.eval()
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 4559d713ed..c77b000292 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -29,6 +29,10 @@ from diffusers import (
DDIMScheduler,
DDPMPipeline,
DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ EulerAncestralDiscreteScheduler,
+ EulerDiscreteScheduler,
+ LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipelineLegacy,
@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
assert image_img2img.shape == (1, 32, 32, 3)
assert image_text2img.shape == (1, 128, 128, 3)
+ def test_set_scheduler(self):
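+ # swapping the scheduler in-place, e.g.
+ #     pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ # should work for every compatible scheduler class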
+ unet = self.dummy_cond_unet
+ scheduler = PNDMScheduler(skip_prk_steps=True)
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ sd = StableDiffusionPipeline(
+ unet=unet,
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=None,
+ feature_extractor=self.dummy_extractor,
+ )
+
+ sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
+ assert isinstance(sd.scheduler, DDIMScheduler)
+ sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config)
+ assert isinstance(sd.scheduler, DDPMScheduler)
+ sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
+ assert isinstance(sd.scheduler, PNDMScheduler)
+ sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config)
+ assert isinstance(sd.scheduler, LMSDiscreteScheduler)
+ sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config)
+ assert isinstance(sd.scheduler, EulerDiscreteScheduler)
+ sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config)
+ assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler)
+ sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
+ assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
+
+ def test_set_scheduler_consistency(self):
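+ # round-tripping a scheduler config through a different scheduler class
+ # must preserve all of the original config values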
+ unet = self.dummy_cond_unet
+ pndm = PNDMScheduler.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+ ddim = DDIMScheduler.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+ vae = self.dummy_vae
+ bert = self.dummy_text_encoder
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ sd = StableDiffusionPipeline(
+ unet=unet,
+ scheduler=pndm,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=None,
+ feature_extractor=self.dummy_extractor,
+ )
+
+ pndm_config = sd.scheduler.config
+ sd.scheduler = DDPMScheduler.from_config(pndm_config)
+ sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
+ pndm_config_2 = sd.scheduler.config
+ pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config}
+
+ assert dict(pndm_config) == dict(pndm_config_2)
+
+ sd = StableDiffusionPipeline(
+ unet=unet,
+ scheduler=ddim,
+ vae=vae,
+ text_encoder=bert,
+ tokenizer=tokenizer,
+ safety_checker=None,
+ feature_extractor=self.dummy_extractor,
+ )
+
+ ddim_config = sd.scheduler.config
+ sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config)
+ sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
+ ddim_config_2 = sd.scheduler.config
+ ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config}
+
+ assert dict(ddim_config) == dict(ddim_config_2)
+
@slow
class PipelineSlowTests(unittest.TestCase):
@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
def test_output_format(self):
model_path = "google/ddpm-cifar10-32"
- scheduler = DDIMScheduler.from_config(model_path)
+ scheduler = DDIMScheduler.from_pretrained(model_path)
pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index a9770f0a54..9c9abd0973 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
+import json
+import os
import tempfile
import unittest
from typing import Dict, List, Tuple
@@ -21,6 +23,7 @@ import numpy as np
import torch
import torch.nn.functional as F
+import diffusers
from diffusers import (
DDIMScheduler,
DDPMScheduler,
@@ -32,13 +35,180 @@ from diffusers import (
PNDMScheduler,
ScoreSdeVeScheduler,
VQDiffusionScheduler,
+ logging,
)
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import deprecate, torch_device
+from diffusers.utils.testing_utils import CaptureLogger
torch.backends.cuda.matmul.allow_tf32 = False
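+# Minimal SchedulerMixin/ConfigMixin dummies with partially overlapping config
+# signatures, used below to exercise loading a config saved by one scheduler
+# class into another.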
+class SchedulerObject(SchedulerMixin, ConfigMixin):
+ config_name = "config.json"
+
+ @register_to_config
+ def __init__(
+ self,
+ a=2,
+ b=5,
+ c=(2, 5),
+ d="for diffusion",
+ e=[1, 3],
+ ):
+ pass
+
+
+class SchedulerObject2(SchedulerMixin, ConfigMixin):
+ config_name = "config.json"
+
+ @register_to_config
+ def __init__(
+ self,
+ a=2,
+ b=5,
+ c=(2, 5),
+ d="for diffusion",
+ f=[1, 3],
+ ):
+ pass
+
+
+class SchedulerObject3(SchedulerMixin, ConfigMixin):
+ config_name = "config.json"
+
+ @register_to_config
+ def __init__(
+ self,
+ a=2,
+ b=5,
+ c=(2, 5),
+ d="for diffusion",
+ e=[1, 3],
+ f=[1, 3],
+ ):
+ pass
+
+
+class SchedulerBaseTests(unittest.TestCase):
+ def test_save_load_from_different_config(self):
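+ # a config saved by SchedulerObject should load into SchedulerObject2 without
+ # warnings; unexpected config keys must be reported for either class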
+ obj = SchedulerObject()
+
+ # mock add obj class to `diffusers`
+ setattr(diffusers, "SchedulerObject", SchedulerObject)
+ logger = logging.get_logger("diffusers.configuration_utils")
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ obj.save_config(tmpdirname)
+ with CaptureLogger(logger) as cap_logger_1:
+ config = SchedulerObject2.load_config(tmpdirname)
+ new_obj_1 = SchedulerObject2.from_config(config)
+
+ # now save a config parameter that is not expected
+ with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
+ data = json.load(f)
+ data["unexpected"] = True
+
+ with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
+ json.dump(data, f)
+
+ with CaptureLogger(logger) as cap_logger_2:
+ config = SchedulerObject.load_config(tmpdirname)
+ new_obj_2 = SchedulerObject.from_config(config)
+
+ with CaptureLogger(logger) as cap_logger_3:
+ config = SchedulerObject2.load_config(tmpdirname)
+ new_obj_3 = SchedulerObject2.from_config(config)
+
+ assert new_obj_1.__class__ == SchedulerObject2
+ assert new_obj_2.__class__ == SchedulerObject
+ assert new_obj_3.__class__ == SchedulerObject2
+
+ assert cap_logger_1.out == ""
+ assert (
+ cap_logger_2.out
+ == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
+ " will"
+ " be ignored. Please verify your config.json configuration file.\n"
+ )
+ assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out
+
+ def test_save_load_compatible_schedulers(self):
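+ # a config key that only a compatible class expects ("f") should load
+ # silently, while truly unexpected keys are still reported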
+ SchedulerObject2._compatibles = ["SchedulerObject"]
+ SchedulerObject._compatibles = ["SchedulerObject2"]
+
+ obj = SchedulerObject()
+
+ # mock add obj class to `diffusers`
+ setattr(diffusers, "SchedulerObject", SchedulerObject)
+ setattr(diffusers, "SchedulerObject2", SchedulerObject2)
+ logger = logging.get_logger("diffusers.configuration_utils")
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ obj.save_config(tmpdirname)
+
+ # now save a config parameter that is expected by another class, but not by the origin class
+ with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
+ data = json.load(f)
+ data["f"] = [0, 0]
+ data["unexpected"] = True
+
+ with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
+ json.dump(data, f)
+
+ with CaptureLogger(logger) as cap_logger:
+ config = SchedulerObject.load_config(tmpdirname)
+ new_obj = SchedulerObject.from_config(config)
+
+ assert new_obj.__class__ == SchedulerObject
+
+ assert (
+ cap_logger.out
+ == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
+ " will"
+ " be ignored. Please verify your config.json configuration file.\n"
+ )
+
+ def test_save_load_from_different_config_comp_schedulers(self):
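+ # when all three classes are mutually compatible, loading a config that lacks
+ # a key the target class expects should only log, at INFO level, that the
+ # default value is used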
+ SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"]
+ SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"]
+ SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"]
+
+ obj = SchedulerObject()
+
+ # mock add obj class to `diffusers`
+ setattr(diffusers, "SchedulerObject", SchedulerObject)
+ setattr(diffusers, "SchedulerObject2", SchedulerObject2)
+ setattr(diffusers, "SchedulerObject3", SchedulerObject3)
+ logger = logging.get_logger("diffusers.configuration_utils")
+ logger.setLevel(diffusers.logging.INFO)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ obj.save_config(tmpdirname)
+
+ with CaptureLogger(logger) as cap_logger_1:
+ config = SchedulerObject.load_config(tmpdirname)
+ new_obj_1 = SchedulerObject.from_config(config)
+
+ with CaptureLogger(logger) as cap_logger_2:
+ config = SchedulerObject2.load_config(tmpdirname)
+ new_obj_2 = SchedulerObject2.from_config(config)
+
+ with CaptureLogger(logger) as cap_logger_3:
+ config = SchedulerObject3.load_config(tmpdirname)
+ new_obj_3 = SchedulerObject3.from_config(config)
+
+ assert new_obj_1.__class__ == SchedulerObject
+ assert new_obj_2.__class__ == SchedulerObject2
+ assert new_obj_3.__class__ == SchedulerObject3
+
+ assert cap_logger_1.out == ""
+ assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
+ assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
+
+
class SchedulerCommonTest(unittest.TestCase):
scheduler_classes = ()
forward_default_kwargs = ()
@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+ def test_compatibles(self):
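+ # every class advertised in `scheduler.compatibles` must be constructible
+ # from this scheduler's config, and converting back must be lossless for
+ # all init arguments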
+ for scheduler_class in self.scheduler_classes:
+ scheduler_config = self.get_scheduler_config()
+
+ scheduler = scheduler_class(**scheduler_config)
+
+ assert all(c is not None for c in scheduler.compatibles)
+
+ for comp_scheduler_cls in scheduler.compatibles:
+ comp_scheduler = comp_scheduler_cls.from_config(scheduler.config)
+ assert comp_scheduler is not None
+
+ new_scheduler = scheduler_class.from_config(comp_scheduler.config)
+
+ new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config}
+ scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config}
+
+ # make sure that configs are essentially identical
+ assert new_scheduler_config == dict(scheduler.config)
+
+ # make sure that the only differences are in config keys that are not init arguments
+ init_keys = inspect.signature(scheduler_class.__init__).parameters.keys()
+ assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set()
+
+ def test_from_pretrained(self):
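+ # a save_pretrained/from_pretrained round-trip must reproduce the config exactly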
+ for scheduler_class in self.scheduler_classes:
+ scheduler_config = self.get_scheduler_config()
+
+ scheduler = scheduler_class(**scheduler_config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ scheduler.save_pretrained(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+
+ assert scheduler.config == new_scheduler.config
+
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps)
@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps)
@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
output = scheduler.step_pred(
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
output = scheduler.step_pred(
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler = scheduler_class.from_config(tmpdirname)
+ new_scheduler = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_scheduler.set_timesteps(num_inference_steps)
diff --git a/tests/test_scheduler_flax.py b/tests/test_scheduler_flax.py
index 7928939f2d..0fa0e1b495 100644
--- a/tests/test_scheduler_flax.py
+++ b/tests/test_scheduler_flax.py
@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
state = scheduler.set_timesteps(state, num_inference_steps)
@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
# copy over dummy past residuals
new_state = new_state.replace(ets=dummy_past_residuals[:])
@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+ new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
# copy over dummy past residuals
new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)