mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge remote-tracking branch 'origin' into 1d_blocks

This commit is contained in:
Nathan Lambert
2022-11-29 15:12:56 -08:00
182 changed files with 19006 additions and 1257 deletions

View File

@@ -9,11 +9,8 @@ concurrency:
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@use_hf_hub
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: diffusers
secrets:
token: ${{ secrets.HF_DOC_PUSH }}
comment_bot_token: ${{ secrets.HUGGINGFACE_PUSH }}

View File

@@ -7,10 +7,7 @@ on:
jobs:
delete:
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@use_hf_hub
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
with:
pr_number: ${{ github.event.number }}
package: diffusers
secrets:
token: ${{ secrets.HF_DOC_PUSH }}
comment_bot_token: ${{ secrets.HUGGINGFACE_PUSH }}

View File

@@ -60,6 +60,7 @@ jobs:
run: |
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
@@ -127,6 +128,7 @@ jobs:
${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
shell: arch -arch arm64 bash {0}

View File

@@ -62,6 +62,7 @@ jobs:
run: |
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
@@ -131,6 +132,7 @@ jobs:
run: |
python -m pip install -e .[quality,test,training]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |

View File

@@ -152,15 +152,7 @@ it before the pipeline and pass it to `from_pretrained`.
```python
from diffusers import LMSDiscreteScheduler
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
scheduler=lms,
)
pipe = pipe.to("cuda")
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
@@ -353,7 +345,8 @@ Textual Inversion is a technique for capturing novel concepts from a small numbe
## Stable Diffusion Community Pipelines
The release of Stable Diffusion as an open source model has fostered a lot of interesting ideas and experimentation. Our [Community Examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community) contains many ideas worth exploring, like interpolating to create animated videos, using CLIP Guidance for additional prompt fidelity, term weighting, and much more! [Take a look](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) and [contribute your own](https://huggingface.co/docs/diffusers/using-diffusers/contribute_pipeline).
The release of Stable Diffusion as an open source model has fostered a lot of interesting ideas and experimentation.
Our [Community Examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community) contains many ideas worth exploring, like interpolating to create animated videos, using CLIP Guidance for additional prompt fidelity, term weighting, and much more! [Take a look](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) and [contribute your own](https://huggingface.co/docs/diffusers/using-diffusers/contribute_pipeline).
## Other Examples
@@ -402,10 +395,14 @@ image.save("ddpm_generated_image.png")
- [Unconditional Latent Diffusion](https://huggingface.co/CompVis/ldm-celebahq-256)
- [Unconditional Diffusion with continuous scheduler](https://huggingface.co/google/ncsnpp-ffhq-1024)
**Other Notebooks**:
**Other Image Notebooks**:
* [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
* [tweak images via repeated Stable Diffusion seeds](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
**Diffusers for Other Modalities**:
* [Molecule conformation generation](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
* [Model-based reinforcement learning](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg),
### Web Demos
If you just want to play around with some web demos, you can try out the following 🚀 Spaces:
| Model | Hugging Face Spaces |

View File

@@ -10,6 +10,8 @@
- sections:
- local: using-diffusers/loading
title: "Loading Pipelines, Models, and Schedulers"
- local: using-diffusers/schedulers
title: "Using different Schedulers"
- local: using-diffusers/configuration
title: "Configuring Pipelines, Models, and Schedulers"
- local: using-diffusers/custom_pipeline_overview
@@ -29,6 +31,14 @@
- local: using-diffusers/contribute_pipeline
title: "How to contribute a Pipeline"
title: "Pipelines for Inference"
- sections:
- local: using-diffusers/rl
title: "Reinforcement Learning"
- local: using-diffusers/audio
title: "Audio"
- local: using-diffusers/other-modalities
title: "Other Modalities"
title: "Taking Diffusers Beyond Images"
title: "Using Diffusers"
- sections:
- local: optimization/fp16
@@ -78,6 +88,8 @@
- sections:
- local: api/pipelines/overview
title: "Overview"
- local: api/pipelines/alt_diffusion
title: "AltDiffusion"
- local: api/pipelines/cycle_diffusion
title: "Cycle Diffusion"
- local: api/pipelines/ddim
@@ -94,13 +106,23 @@
title: "Score SDE VE"
- local: api/pipelines/stable_diffusion
title: "Stable Diffusion"
- local: api/pipelines/stable_diffusion_2
title: "Stable Diffusion 2"
- local: api/pipelines/stable_diffusion_safe
title: "Safe Stable Diffusion"
- local: api/pipelines/stochastic_karras_ve
title: "Stochastic Karras VE"
- local: api/pipelines/dance_diffusion
title: "Dance Diffusion"
- local: api/pipelines/versatile_diffusion
title: "Versatile Diffusion"
- local: api/pipelines/vq_diffusion
title: "VQ Diffusion"
- local: api/pipelines/repaint
title: "RePaint"
title: "Pipelines"
- sections:
- local: api/experimental/rl
title: "RL Planning"
title: "Experimental Features"
title: "API"

View File

@@ -15,9 +15,9 @@ specific language governing permissions and limitations under the License.
In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are
passed to the respective `__init__` methods in a JSON-configuration file.
TODO(PVP) - add example and better info here
## ConfigMixin
[[autodoc]] ConfigMixin
- load_config
- from_config
- save_config
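
The idea of storing `__init__` parameters in a JSON file can be sketched without the library. The classes `ToyConfigMixin` and `ToyScheduler`, the helper `_record_init_args`, and the file name `toy_config.json` are hypothetical names chosen for illustration; only the method names `save_config`, `load_config` and `from_config` echo the ones listed above, and this is not how `ConfigMixin` is actually implemented.

```python
import json
import os
import tempfile


class ToyConfigMixin:
    """Hypothetical stand-in for the idea behind ConfigMixin (not the real class)."""

    config_name = "toy_config.json"  # assumed file name, for illustration only

    def _record_init_args(self, **kwargs):
        # Remember every parameter that was passed to __init__.
        self.config = dict(kwargs)

    def save_config(self, save_directory):
        # Write the recorded parameters to a JSON file in the given directory.
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(save_directory, self.config_name)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.config, f, indent=2, sort_keys=True)
        return path

    @classmethod
    def load_config(cls, save_directory):
        # Read the JSON file back into a plain dictionary.
        path = os.path.join(save_directory, cls.config_name)
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    @classmethod
    def from_config(cls, config):
        # Rebuild an instance from a configuration dictionary.
        return cls(**config)


class ToyScheduler(ToyConfigMixin):
    def __init__(self, num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self._record_init_args(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
        )


def roundtrip(scheduler):
    """Save a scheduler's config to a temporary directory and rebuild it."""
    with tempfile.TemporaryDirectory() as tmp:
        scheduler.save_config(tmp)
        config = type(scheduler).load_config(tmp)
    return type(scheduler).from_config(config)
```

Calling `roundtrip(ToyScheduler(num_train_timesteps=50))` returns a new object whose `config` matches the original one.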

View File

@@ -0,0 +1,15 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# TODO
Coming soon!

View File

@@ -0,0 +1,83 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AltDiffusion
AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, and Ledell Wu.
The abstract of the paper is the following:
*In this work, we present a conceptually simple and effective method to train a strong bilingual multimodal representation model. Starting from the pretrained multimodal representation model CLIP released by OpenAI, we switched its text encoder with a pretrained multilingual text encoder XLM-R, and aligned both languages and image representations by a two-stage training schema consisting of teacher learning and contrastive learning. We validate our method through evaluations of a wide range of tasks. We set new state-of-the-art performances on a bunch of tasks including ImageNet-CN, Flicker30k-CN, and COCO-CN. Further, we obtain very close performances with CLIP on almost all tasks, suggesting that one can simply alter the text encoder in CLIP for extended capabilities such as multilingual understanding.*
*Overview*:
| Pipeline | Tasks | Colab | Demo
|---|---|:---:|:---:|
| [pipeline_alt_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py) | *Text-to-Image Generation* | - | -
| [pipeline_alt_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | - |-
## Tips
- AltDiffusion is conceptually exactly the same as [Stable Diffusion](./api/pipelines/stable_diffusion).
- *Run AltDiffusion*
AltDiffusion can be run with [`AltDiffusionPipeline`] or [`AltDiffusionImg2ImgPipeline`] and the `"BAAI/AltDiffusion-m9"` checkpoint, in exactly the same way as shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation) and the [Image-to-Image Generation Guide](./using-diffusers/img2img).
- *How to load and use different schedulers.*
The AltDiffusion pipeline uses the [`DDIMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with it, such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], and [`EulerAncestralDiscreteScheduler`].
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import AltDiffusionPipeline, EulerDiscreteScheduler
>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("BAAI/AltDiffusion-m9", subfolder="scheduler")
>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", scheduler=euler_scheduler)
```
- *How to cover all use cases with multiple or single pipeline*
If you want to use all possible use cases in a single `DiffusionPipeline` we recommend using the `components` functionality to instantiate all components in the most memory-efficient way:
```python
>>> from diffusers import (
... AltDiffusionPipeline,
... AltDiffusionImg2ImgPipeline,
... )
>>> text2img = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9")
>>> img2img = AltDiffusionImg2ImgPipeline(**text2img.components)
>>> # now you can use text2img(...) and img2img(...) just like the call methods of each respective pipeline
```
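
The memory saving comes from both pipelines holding references to the same component objects rather than copies. A library-independent sketch of that pattern follows; `ToyPipeline`, `ToyText2Img` and `ToyImg2Img` are hypothetical classes, and their `components` property is only assumed to behave like the one used above.

```python
class ToyPipeline:
    """Hypothetical pipeline that only stores its components."""

    def __init__(self, unet, vae, scheduler):
        self.unet = unet
        self.vae = vae
        self.scheduler = scheduler

    @property
    def components(self):
        # Map constructor argument names to the objects already in memory.
        return {"unet": self.unet, "vae": self.vae, "scheduler": self.scheduler}


class ToyText2Img(ToyPipeline):
    pass


class ToyImg2Img(ToyPipeline):
    pass


# Plain objects stand in for large model weights.
text2img = ToyText2Img(unet=object(), vae=object(), scheduler=object())

# Unpacking `components` passes the same objects, so nothing is loaded twice.
img2img = ToyImg2Img(**text2img.components)

# True when every component of img2img is the very same object as in text2img.
shared = all(
    getattr(img2img, name) is component
    for name, component in text2img.components.items()
)
print(shared)
```

Because the objects are shared, changing a component on one toy pipeline in place is visible from the other one as well.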
## AltDiffusionPipelineOutput
[[autodoc]] pipelines.alt_diffusion.AltDiffusionPipelineOutput
## AltDiffusionPipeline
[[autodoc]] AltDiffusionPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
## AltDiffusionImg2ImgPipeline
[[autodoc]] AltDiffusionImg2ImgPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing

View File

@@ -39,7 +39,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image

View File

@@ -39,9 +39,9 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
## LDMTextToImagePipeline
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
[[autodoc]] LDMTextToImagePipeline
- __call__
## LDMSuperResolutionPipeline
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution.LDMSuperResolutionPipeline
[[autodoc]] LDMSuperResolutionPipeline
- __call__

View File

@@ -44,11 +44,13 @@ available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
@@ -56,7 +58,14 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |

View File

@@ -54,7 +54,7 @@ original_image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))
# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256")
scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
pipe = pipe.to("cuda")

View File

@@ -34,17 +34,21 @@ For more details about how Stable Diffusion works and how it differs from the ba
### How to load and use different schedulers.
The stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
```
### How to conver all use cases with multiple or single pipeline
### How to convert all use cases with multiple or single pipeline
If you want to use all possible use cases in a single `DiffusionPipeline` you can either:
- Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or
@@ -57,11 +61,11 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
... StableDiffusionInpaintPipeline,
... )
>>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
>>> text2img = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
>>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
>>> # now you can use img2text(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline
>>> # now you can use text2img(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline
```
## StableDiffusionPipelineOutput
@@ -72,6 +76,8 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_vae_slicing
- disable_vae_slicing
## StableDiffusionImg2ImgPipeline
[[autodoc]] StableDiffusionImg2ImgPipeline
@@ -84,3 +90,17 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
- __call__
- enable_attention_slicing
- disable_attention_slicing
## StableDiffusionImageVariationPipeline
[[autodoc]] StableDiffusionImageVariationPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
## StableDiffusionUpscalePipeline
[[autodoc]] StableDiffusionUpscalePipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing

View File

@@ -0,0 +1,142 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Diffusion 2
Stable Diffusion 2 is a text-to-image _latent diffusion_ model built upon the work of [Stable Diffusion 1](https://stability.ai/blog/stable-diffusion-public-release).
The project to train Stable Diffusion 2 was led by Robin Rombach and Katherine Crowson from [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/).
*The Stable Diffusion 2.0 release includes robust text-to-image models trained using a brand new text encoder (OpenCLIP), developed by LAION with support from Stability AI, which greatly improves the quality of the generated images compared to earlier V1 releases. The text-to-image models in this release can generate images with default resolutions of both 512x512 pixels and 768x768 pixels.
These models are trained on an aesthetic subset of the [LAION-5B dataset](https://laion.ai/blog/laion-5b/) created by the DeepFloyd team at Stability AI, which is then further filtered to remove adult content using [LAION's NSFW filter](https://openreview.net/forum?id=M3Y74vmsMcY).*
For more details about how Stable Diffusion 2 works and how it differs from Stable Diffusion 1, please refer to the official [launch announcement post](https://stability.ai/blog/stable-diffusion-v2-release).
## Tips
### Available checkpoints:
Note that the architecture is more or less identical to [Stable Diffusion 1](./api/pipelines/stable_diffusion) so please refer to [this page](./api/pipelines/stable_diffusion) for API documentation.
- *Text-to-Image (512x512 resolution)*: [stabilityai/stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) with [`StableDiffusionPipeline`]
- *Text-to-Image (768x768 resolution)*: [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) with [`StableDiffusionPipeline`]
- *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`]
- *Image Upscaling (x4 resolution)*: [stabilityai/stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) with [`StableDiffusionUpscalePipeline`]
We recommend using the [`DPMSolverMultistepScheduler`], as it is currently the fastest scheduler available; the examples below use it with 25 inference steps.
- *Text-to-Image (512x512 resolution)*:
```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
repo_id = "stabilityai/stable-diffusion-2-base"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "High quality photo of an astronaut riding a horse in space"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("astronaut.png")
```
- *Text-to-Image (768x768 resolution)*:
```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
repo_id = "stabilityai/stable-diffusion-2"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "High quality photo of an astronaut riding a horse in space"
image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0]
image.save("astronaut.png")
```
- *Image Inpainting (512x512 resolution)*:
```python
import PIL
import requests
import torch
from io import BytesIO
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
repo_id = "stabilityai/stable-diffusion-2-inpainting"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=25).images[0]
image.save("yellow_cat.png")
```
- *Image Upscaling (x4 resolution)*: [stabilityai/stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) with [`StableDiffusionUpscalePipeline`]
```python
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionUpscalePipeline
import torch
# load model and scheduler
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")
# let's download an image
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
response = requests.get(url)
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
low_res_img = low_res_img.resize((128, 128))
prompt = "a white cat"
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image.save("upsampled_cat.png")
```
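
Assuming the x4 upscaler multiplies each side of the input by four, as the checkpoint name suggests, the 128x128 input above would yield a 512x512 result. The helper below is hypothetical and independent of the library; it only does the size arithmetic.

```python
def upscaled_size(size, factor=4):
    """Return the (width, height) expected after upscaling each side by `factor`."""
    width, height = size
    return (width * factor, height * factor)


print(upscaled_size((128, 128)))  # (512, 512)
```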
### How to load and use different schedulers.
The Stable Diffusion 2 pipeline uses the [`DDIMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with it, such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], and [`EulerAncestralDiscreteScheduler`].
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=euler_scheduler)
```
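
One way to picture why `from_config(pipeline.scheduler.config)` can build a different scheduler class: the new class reads the settings it understands from the old scheduler's config and falls back to its own defaults for the rest. The sketch below illustrates that idea with hypothetical toy classes (`ToySchedulerBase`, `ToyDDIM`, `ToyEuler`); filtering by `__init__` signature is an assumption made for illustration, not a description of the library's internals.

```python
import inspect


class ToySchedulerBase:
    @classmethod
    def from_config(cls, config):
        # Keep only the keys that this class's __init__ accepts.
        accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
        kwargs = {k: v for k, v in config.items() if k in accepted}
        return cls(**kwargs)


class ToyDDIM(ToySchedulerBase):
    def __init__(self, num_train_timesteps=1000, beta_start=0.0001, clip_sample=True):
        self.config = {
            "num_train_timesteps": num_train_timesteps,
            "beta_start": beta_start,
            "clip_sample": clip_sample,
        }


class ToyEuler(ToySchedulerBase):
    def __init__(self, num_train_timesteps=1000, beta_start=0.0001):
        self.config = {
            "num_train_timesteps": num_train_timesteps,
            "beta_start": beta_start,
        }


old = ToyDDIM(num_train_timesteps=500, beta_start=0.00085)

# Shared settings carry over; `clip_sample` is dropped because ToyEuler has no such argument.
new = ToyEuler.from_config(old.config)
print(type(new).__name__, new.config)
```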

View File

@@ -0,0 +1,90 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Safe Stable Diffusion
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://arxiv.org/abs/2211.05105) and mitigates the well known issue that models like Stable Diffusion that are trained on unfiltered, web-crawled datasets tend to suffer from inappropriate degeneration. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, or otherwise offensive content.
Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces content like this.
The abstract of the paper is the following:
*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.*
*Overview*:
| Pipeline | Tasks | Colab | Demo
|---|---|:---:|:---:|
| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | -
## Tips
- Safe Stable Diffusion may also be used with weights of [Stable Diffusion](./api/pipelines/stable_diffusion).
### Run Safe Stable Diffusion
Safe Stable Diffusion can be tested very easily with the [`StableDiffusionPipelineSafe`] and the `"AIML-TUDA/stable-diffusion-safe"` checkpoint, in exactly the same way as shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation).
### Interacting with the Safety Concept
To check and edit the currently used safety concept, use the `safety_concept` property of [`StableDiffusionPipelineSafe`]:
```python
>>> from diffusers import StableDiffusionPipelineSafe
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> pipeline.safety_concept
```
For each image generation, the active concept is also contained in [`StableDiffusionSafePipelineOutput`].
### Using pre-defined safety configurations
You may use the 4 configurations defined in the [Safe Latent Diffusion paper](https://arxiv.org/abs/2211.05105) as follows:
```python
>>> from diffusers import StableDiffusionPipelineSafe
>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
```
The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONG`, and `SafetyConfig.MAX`.
### How to load and use different schedulers.
The Safe Stable Diffusion pipeline uses the [`PNDMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with the Stable Diffusion pipeline, such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import StableDiffusionPipelineSafe, EulerDiscreteScheduler
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("AIML-TUDA/stable-diffusion-safe", subfolder="scheduler")
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained(
... "AIML-TUDA/stable-diffusion-safe", scheduler=euler_scheduler
... )
```
## StableDiffusionSafePipelineOutput
[[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput
## StableDiffusionPipelineSafe
[[autodoc]] StableDiffusionPipelineSafe
- __call__
- enable_attention_slicing
- disable_attention_slicing

View File

@@ -0,0 +1,73 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# VersatileDiffusion
VersatileDiffusion was proposed in [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) by Xingqian Xu, Zhangyang Wang, Eric Zhang, Kai Wang, Humphrey Shi.
The abstract of the paper is the following:
*The recent advances in diffusion models have set an impressive milestone in many generation tasks. Trending works such as DALL-E2, Imagen, and Stable Diffusion have attracted great interest in academia and industry. Despite the rapid landscape changes, recent new approaches focus on extensions and performance rather than capacity, thus requiring separate models for separate tasks. In this work, we expand the existing single-flow diffusion pipeline into a multi-flow network, dubbed Versatile Diffusion (VD), that handles text-to-image, image-to-text, image-variation, and text-variation in one unified model. Moreover, we generalize VD to a unified multi-flow multimodal diffusion framework with grouped layers, swappable streams, and other propositions that can process modalities beyond images and text. Through our experiments, we demonstrate that VD and its underlying framework have the following merits: a) VD handles all subtasks with competitive quality; b) VD initiates novel extensions and applications such as disentanglement of style and semantic, image-text dual-guided generation, etc.; c) Through these experiments and applications, VD provides more semantic insights of the generated outputs.*
## Tips
- VersatileDiffusion is conceptually very similar to [Stable Diffusion](./api/pipelines/stable_diffusion), but instead of providing just an image data stream conditioned on text, VersatileDiffusion provides both an image and a text data stream and can be conditioned on both text and image.
### *Run VersatileDiffusion*
You can either load the memory-intensive "all-in-one" [`VersatileDiffusionPipeline`], which can run all tasks
with the same class as shown in [`VersatileDiffusionPipeline.text_to_image`], [`VersatileDiffusionPipeline.image_variation`], and [`VersatileDiffusionPipeline.dual_guided`]
**or**
run the individual pipelines, which are much more memory efficient:
- *Text-to-Image*: [`VersatileDiffusionTextToImagePipeline.__call__`]
- *Image Variation*: [`VersatileDiffusionImageVariationPipeline.__call__`]
- *Dual Text and Image Guided Generation*: [`VersatileDiffusionDualGuidedPipeline.__call__`]
### *How to load and use different schedulers.*
The Versatile Diffusion pipelines use the [`DDIMScheduler`] by default, but `diffusers` provides many other schedulers that can be used with the Versatile Diffusion pipelines, such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import VersatileDiffusionPipeline, EulerDiscreteScheduler
>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("shi-labs/versatile-diffusion", subfolder="scheduler")
>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", scheduler=euler_scheduler)
```
## VersatileDiffusionPipeline
[[autodoc]] VersatileDiffusionPipeline
## VersatileDiffusionTextToImagePipeline
[[autodoc]] VersatileDiffusionTextToImagePipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
## VersatileDiffusionImageVariationPipeline
[[autodoc]] VersatileDiffusionImageVariationPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
## VersatileDiffusionDualGuidedPipeline
[[autodoc]] VersatileDiffusionDualGuidedPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing

View File

@@ -34,11 +34,13 @@ available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image |
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
@@ -46,7 +48,14 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.

View File

@@ -117,6 +117,34 @@ image = pipe(prompt).images[0]
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
## Sliced VAE decode for larger batches
To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time.
You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:
```Python
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
images = pipe([prompt] * 32).images
```
You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
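The idea behind sliced decoding can be sketched in a few lines of plain Python. This is an illustrative sketch only, not the diffusers implementation: `decode_batch` is a hypothetical stand-in for the VAE decoder, and latents are modelled as lists of numbers. It shows that decoding one latent at a time and concatenating the results gives the same output as decoding the full batch, while the decoder only ever sees a single item.

```python
from typing import Callable, List

Latent = List[float]
Image = List[float]


def decode_batch(latents: List[Latent]) -> List[Image]:
    """Hypothetical stand-in for a VAE decoder: scales every value by 2."""
    return [[2.0 * value for value in latent] for latent in latents]


def sliced_decode(
    latents: List[Latent],
    decode: Callable[[List[Latent]], List[Image]] = decode_batch,
) -> List[Image]:
    """Decode one latent at a time, then concatenate the per-item results."""
    images: List[Image] = []
    for latent in latents:
        # Each call sees a batch of size 1, so the decoder handles one item at a time.
        images.extend(decode([latent]))
    return images


latents = [[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]]
# Slicing changes how the work is scheduled, not the result.
assert sliced_decode(latents) == decode_batch(latents)
```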
## Offloading to CPU with accelerate for memory savings
For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass.

View File

@@ -41,7 +41,7 @@ In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generat
```python
>>> from diffusers import DiffusionPipeline
>>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
```
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
@@ -49,13 +49,13 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm
You can move the generator object to GPU, just like you would in PyTorch.
```python
>>> generator.to("cuda")
>>> pipeline.to("cuda")
```
Now you can use the `generator` on your text prompt:
Now you can use the `pipeline` on your text prompt:
```python
>>> image = generator("An image of a squirrel in Picasso style").images[0]
>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
```
The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).
@@ -82,7 +82,7 @@ just like we did before only that now you need to pass your `AUTH_TOKEN`:
```python
>>> from diffusers import DiffusionPipeline
>>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
```
If you do not pass your authentication token you will see that the diffusion system will not be correctly
@@ -102,7 +102,7 @@ token. Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned
you can also load the pipeline as follows:
```python
>>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
```
Running the pipeline is then identical to the code above as it's the same model architecture.
@@ -115,19 +115,20 @@ Running the pipeline is then identical to the code above as it's the same model
Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their
pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler,
use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`] scheduler,
you could use it as follows:
```python
>>> from diffusers import LMSDiscreteScheduler
>>> from diffusers import EulerDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
>>> generator = StableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
... )
>>> # change scheduler to Euler
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
```
For more in-detail information on how to change between schedulers, please refer to the [Using Schedulers](./using-diffusers/schedulers) guide.
[Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model
and can do much more than just generating images from text. We have dedicated a whole documentation page,
just for Stable Diffusion [here](./conceptual/stable_diffusion).

View File

@@ -0,0 +1,16 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Using Diffusers for audio
The [`DanceDiffusionPipeline`] can be used to generate audio rapidly!
More coming soon!

View File

@@ -44,5 +44,3 @@ You can save the image by simply calling:
```python
>>> image.save("image_of_squirrel_painting.png")
```

View File

@@ -19,7 +19,7 @@ In the following we explain in-detail how to easily load:
- *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`]
- *Diffusion Models* via [`ModelMixin.from_pretrained`]
- *Schedulers* via [`ConfigMixin.from_config`]
- *Schedulers* via [`SchedulerMixin.from_pretrained`]
## Loading pipelines
@@ -137,15 +137,15 @@ from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultis
repo_id = "runwayml/stable-diffusion-v1-5"
scheduler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
# or
# scheduler = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
# scheduler = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)
```
Three things are worth paying attention to here.
- First, the scheduler is loaded with [`ConfigMixin.from_config`] since it only depends on a configuration file and not any parameterized weights
- First, the scheduler is loaded with [`SchedulerMixin.from_pretrained`]
- Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"` as the configuration of stable diffusion's scheduling is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler)
- Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler` since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`]
@@ -337,8 +337,8 @@ model = UNet2DModel.from_pretrained(repo_id)
## Loading schedulers
Schedulers cannot be loaded via a `from_pretrained` method, but instead rely on [`ConfigMixin.from_config`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
Therefore the loading method was given a different name here.
Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
For consistency, we use the same method name as we do for models or pipelines, but no weights are loaded in this case.
In contrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers.
For example, all of:
@@ -367,13 +367,13 @@ from diffusers import (
repo_id = "runwayml/stable-diffusion-v1-5"
ddpm = DDPMScheduler.from_config(repo_id, subfolder="scheduler")
ddim = DDIMScheduler.from_config(repo_id, subfolder="scheduler")
pndm = PNDMScheduler.from_config(repo_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
euler_anc = EulerAncestralDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
euler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc`
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)

View File

@@ -0,0 +1,20 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Using Diffusers with other modalities
Diffusers is in the process of expanding to modalities other than images.
Currently, one example is for [molecule conformation](https://www.nature.com/subjects/molecular-conformation#:~:text=Definition,to%20changes%20in%20their%20environment.) generation.
* Generate conformations in Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb)
More coming soon!

View File

@@ -0,0 +1,18 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Using Diffusers for reinforcement learning
Support for one RL model and related pipelines is included in the `experimental` source of diffusers.
To try some of this in Colab, please look at the following example:
* Model-based reinforcement learning on Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb)

View File

@@ -0,0 +1,262 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Schedulers
Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize
a pipeline to one's use case. The best example of this is the [Schedulers](../api/schedulers.mdx).
Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample,
schedulers define the whole denoising process, *i.e.*:
- How many denoising steps?
- Stochastic or deterministic?
- What algorithm to use to find the denoised sample?
They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**.
It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try out which works best.
The following paragraphs show how to do so with the 🧨 Diffusers library.
## Load pipeline
Let's start by loading the stable diffusion pipeline.
Remember that you have to be a registered user on the 🤗 Hugging Face Hub, and have "click-accepted" the [license](https://huggingface.co/runwayml/stable-diffusion-v1-5) in order to use stable diffusion.
```python
from huggingface_hub import login
from diffusers import DiffusionPipeline
import torch
# first we need to login with our access token
login()
# Now we can download the pipeline
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```
Next, we move it to GPU:
```python
pipeline.to("cuda")
```
## Access the scheduler
The scheduler is always one of the components of the pipeline and is usually called `"scheduler"`.
So it can be accessed via the `"scheduler"` property.
```python
pipeline.scheduler
```
**Output**:
```
PNDMScheduler {
"_class_name": "PNDMScheduler",
"_diffusers_version": "0.8.0.dev0",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": false,
"num_train_timesteps": 1000,
"set_alpha_to_one": false,
"skip_prk_steps": true,
"steps_offset": 1,
"trained_betas": null
}
```
We can see that the scheduler is of type [`PNDMScheduler`].
Cool, now let's compare this scheduler's performance to that of other schedulers.
First we define a prompt on which we will test all the different schedulers:
```python
prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
```
Next, we create a generator from a fixed seed, which ensures that we can generate similar images, and run the pipeline:
```python
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_pndm.png" width="400"/>
<br>
</p>
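Creating a new generator with the same seed before each run is what keeps runs comparable: a generator is stateful, so reusing one object across calls would hand each run different starting noise. The sketch below illustrates this with Python's standard `random.Random` as a stand-in for `torch.Generator`; `sample_noise` is a hypothetical helper introduced here and is not part of diffusers.

```python
import random
from typing import List


def sample_noise(generator: random.Random, size: int = 4) -> List[float]:
    """Hypothetical helper: draw `size` Gaussian values from a seeded generator."""
    return [generator.gauss(0.0, 1.0) for _ in range(size)]


# A fresh generator with the same seed reproduces the same noise.
first = sample_noise(random.Random(8))
second = sample_noise(random.Random(8))
assert first == second

# Reusing one generator advances its state, so the second draw differs.
shared = random.Random(8)
draw_a = sample_noise(shared)
draw_b = sample_noise(shared)
assert draw_a == first
assert draw_a != draw_b
```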
## Changing the scheduler
Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`SchedulerMixin.compatibles`]
which defines all compatible schedulers. You can take a look at all available, compatible schedulers for the Stable Diffusion pipeline as follows.
```python
pipeline.scheduler.compatibles
```
**Output**:
```
[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
diffusers.schedulers.scheduling_ddim.DDIMScheduler,
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
diffusers.schedulers.scheduling_pndm.PNDMScheduler,
diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]
```
Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions:
- [`LMSDiscreteScheduler`],
- [`DDIMScheduler`],
- [`DPMSolverMultistepScheduler`],
- [`EulerDiscreteScheduler`],
- [`PNDMScheduler`],
- [`DDPMScheduler`],
- [`EulerAncestralDiscreteScheduler`].
We will now run the input prompt with each of the other schedulers. To change the scheduler of the pipeline you can make use of the
convenient [`ConfigMixin.config`] property in combination with the [`ConfigMixin.from_config`] function.
```python
pipeline.scheduler.config
```
returns a dictionary of the configuration of the scheduler:
**Output**:
```
FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
('beta_schedule', 'scaled_linear'),
('trained_betas', None),
('skip_prk_steps', True),
('set_alpha_to_one', False),
('steps_offset', 1),
('_class_name', 'PNDMScheduler'),
('_diffusers_version', '0.8.0.dev0'),
('clip_sample', False)])
```
This configuration can then be used to instantiate a scheduler
of a different class that is compatible with the pipeline. Here,
we change the scheduler to the [`DDIMScheduler`].
```python
from diffusers import DDIMScheduler
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
```
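Why can the configuration of one scheduler instantiate a scheduler of a different class? The toy model below sketches one plausible mechanism, assuming that `from_config` keeps only the keys the target class's `__init__` accepts and ignores the rest. `ToyScheduler`, `ToyPNDM`, and `ToyDDIM` are hypothetical stand-ins written for this illustration, not diffusers code.

```python
import inspect
from typing import Any, Dict


class ToyScheduler:
    """Toy base class sketching config-driven construction."""

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyScheduler":
        # Keep only the keys that this class's __init__ actually accepts.
        accepted = set(inspect.signature(cls.__init__).parameters) - {"self"}
        return cls(**{k: v for k, v in config.items() if k in accepted})

    @property
    def config(self) -> Dict[str, Any]:
        # Expose the constructor arguments plus the class name.
        return {**self._init_kwargs, "_class_name": type(self).__name__}


class ToyPNDM(ToyScheduler):
    def __init__(self, num_train_timesteps=1000, beta_start=0.00085, skip_prk_steps=True):
        self._init_kwargs = dict(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            skip_prk_steps=skip_prk_steps,
        )


class ToyDDIM(ToyScheduler):
    def __init__(self, num_train_timesteps=1000, beta_start=0.0001, clip_sample=True):
        self._init_kwargs = dict(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            clip_sample=clip_sample,
        )


pndm = ToyPNDM()
ddim = ToyDDIM.from_config(pndm.config)
# Shared keys carry over; PNDM-only keys such as `skip_prk_steps` are dropped.
assert ddim.config["beta_start"] == 0.00085
assert "skip_prk_steps" not in ddim.config
```

In this toy model, keys missing from the source configuration (here `clip_sample`) simply fall back to the target class's defaults.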
Cool, now we can run the pipeline again to compare the generation quality.
```python
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_ddim.png" width="400"/>
<br>
</p>
## Compare schedulers
So far we have tried running the stable diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`].
A number of better schedulers have been released that can be run with far fewer steps. Let's compare them here:
[`LMSDiscreteScheduler`] usually leads to better results:
```python
from diffusers import LMSDiscreteScheduler
pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_lms.png" width="400"/>
<br>
</p>
[`EulerDiscreteScheduler`] and [`EulerAncestralDiscreteScheduler`] can generate high-quality results with as few as 30 steps.
```python
from diffusers import EulerDiscreteScheduler
pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_discrete.png" width="400"/>
<br>
</p>
and:
```python
from diffusers import EulerAncestralDiscreteScheduler
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_euler_ancestral.png" width="400"/>
<br>
</p>
At the time of writing, [`DPMSolverMultistepScheduler`] arguably gives the best speed/quality trade-off and can be run with as few
as 20 steps.
```python
from diffusers import DPMSolverMultistepScheduler
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
image
```
<p align="center">
<br>
<img src="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/diffusers_docs/astronaut_dpm.png" width="400"/>
<br>
</p>
As you can see, most images look very similar and are arguably of comparable quality. Which scheduler to choose often depends on the specific use case, so a good approach is to run several
schedulers and compare the results.
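The comparison can also be scripted. The following is a minimal sketch, not part of the library: `compare_schedulers` and its `make_generator` argument are hypothetical names, and the function only assumes that the pipeline exposes a `scheduler` attribute with a `config`, that each scheduler class offers `from_config`, and that calling the pipeline returns an object with an `images` list, as in the snippets above.

```python
def compare_schedulers(pipeline, scheduler_classes, prompt, make_generator, num_inference_steps=None):
    """Run `prompt` once per scheduler class and return {class name: first image}."""
    # Capture the config once so every class is built from the same starting point.
    base_config = pipeline.scheduler.config
    images = {}
    for scheduler_cls in scheduler_classes:
        pipeline.scheduler = scheduler_cls.from_config(base_config)
        # A fresh, identically seeded generator keeps the runs comparable.
        kwargs = {"generator": make_generator()}
        if num_inference_steps is not None:
            kwargs["num_inference_steps"] = num_inference_steps
        images[scheduler_cls.__name__] = pipeline(prompt, **kwargs).images[0]
    return images


# Example call, assuming `pipeline`, `prompt`, and `torch` from the snippets above:
# images = compare_schedulers(
#     pipeline,
#     [LMSDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler],
#     prompt,
#     make_generator=lambda: torch.Generator(device="cuda").manual_seed(8),
#     num_inference_steps=30,
# )
```

Passing a factory rather than a single generator is a deliberate choice: a generator's state advances as it is used, so reusing one object would give each scheduler different starting noise.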

View File

@@ -15,11 +15,14 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) |
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech) |
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) |
| Composable Stable Diffusion| Stable Diffusion Pipeline that supports prompts that contain "&#124;" in prompts (as an AND condition) and weights (separated by "&#124;" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "&#124;" in prompts (as an AND condition) and weights (separated by "&#124;" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
| Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
| Multilingual Stable Diffusion| Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting| [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting| [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - |[Stuti R.](https://github.com/kingstut) |
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
@@ -342,9 +345,10 @@ out = pipe(
)
```
### Composable Stable diffusion
[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models.
```python
import torch as th
import numpy as np
@@ -367,7 +371,7 @@ def dummy(images, **kwargs):
pipe.safety_checker = dummy
images = []
generator = th.Generator("cuda").manual_seed(0)
generator = torch.Generator("cuda").manual_seed(0)
seed = 0
prompt = "a forest | a camel"
@@ -396,6 +400,7 @@ import requests
from PIL import Image
from io import BytesIO
import torch
import os
from diffusers import DiffusionPipeline, DDIMScheduler
has_cuda = torch.cuda.is_available()
device = torch.device('cpu' if not has_cuda else 'cuda')
@@ -420,6 +425,7 @@ res = pipe.train(
num_inference_steps=50,
generator=generator)
res = pipe(alpha=1)
os.makedirs("imagic", exist_ok=True)
image = res.images[0]
image.save('./imagic/imagic_image_alpha_1.png')
res = pipe(alpha=1.5)
@@ -596,7 +602,7 @@ For example, this could be used to place a logo on a shirt and make it blend sea
import PIL
import torch
from diffusers import StableDiffusionInpaintPipeline
from diffusers import DiffusionPipeline
image_path = "./path-to-image.png"
inner_image_path = "./path-to-inner-image.png"
@@ -606,13 +612,120 @@ init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
pipe = StableDiffusionInpaintPipeline.from_pretrained(
pipe = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
custom_pipeline="img2img_inpainting",
revision="fp16",
torch_dtype=torch.float16,
torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
prompt = "Your prompt here!"
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
```
![2 by 2 grid demonstrating image to image inpainting.](https://user-images.githubusercontent.com/44398246/203506577-ec303be4-887e-4ebd-a773-c83fcb3dd01a.png)
### Text Based Inpainting Stable Diffusion
Use a text prompt to generate the mask for the area to be inpainted.
Currently uses the CLIPSeg model for mask generation, then calls the standard Stable Diffusion Inpainting pipeline to perform the inpainting.
```python
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import DiffusionPipeline
from PIL import Image
import requests
from torch import autocast
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
pipe = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
custom_pipeline="text_inpainting",
segmentation_model=model,
segmentation_processor=processor
)
pipe = pipe.to("cuda")
url = "https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw).resize((512, 512))
text = "a glass" # will mask out this text
prompt = "a cup" # the masked out region will be replaced with this
with autocast("cuda"):
image = pipe(image=image, text=text, prompt=prompt).images[0]
```
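To illustrate the mask-generation step, here is a minimal sketch of how per-pixel segmentation logits could be turned into grayscale mask values. It is an assumption-laden illustration rather than the pipeline's actual code: `logits_to_mask` and its optional `threshold` parameter are hypothetical, and plain nested lists stand in for the model's output tensor.

```python
import math


def _sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def logits_to_mask(logits, threshold=None):
    """Map a 2D grid of logits to mask values in the 0-255 range.

    With `threshold=None` the mask is soft (probability scaled to 0-255).
    With a threshold in [0, 1] the mask is binary (0 or 255).
    """
    mask = []
    for row in logits:
        mask_row = []
        for logit in row:
            prob = _sigmoid(logit)
            if threshold is not None:
                prob = 1.0 if prob >= threshold else 0.0
            mask_row.append(round(prob * 255))
        mask.append(mask_row)
    return mask


# High logits mark pixels that match the text query ("a glass" above).
binary_mask = logits_to_mask([[0.0, 10.0], [-10.0, 0.0]], threshold=0.5)
```

Bright mask pixels would then mark the region the inpainting step is allowed to repaint.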
### Bit Diffusion
Based on https://arxiv.org/abs/2208.04202, this pipeline is used for diffusion on discrete data, e.g. discrete image data or DNA sequence data. An unconditional discrete image can be generated like this:
```python
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="bit_diffusion")
image = pipe().images[0]
```
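The core idea is to represent each discrete value by its bits, mapped to -1/+1 so that a continuous diffusion model can operate on them. The sketch below shows the round trip for a single 8-bit value in plain Python. The function names `int_to_signed_bits` and `signed_bits_to_int` are hypothetical; the bit order (most significant bit first) and the decode-by-thresholding-at-zero step follow the tensor helpers in the pipeline source.

```python
BITS = 8


def int_to_signed_bits(value, bits=BITS):
    """Encode an integer in [0, 2**bits - 1] as -1.0/+1.0 values, most significant bit first."""
    if not 0 <= value < 2 ** bits:
        raise ValueError(f"value must be in [0, {2 ** bits - 1}], got {value}")
    return [1.0 if (value >> shift) & 1 else -1.0 for shift in range(bits - 1, -1, -1)]


def signed_bits_to_int(signed_bits):
    """Decode by thresholding each entry at 0, so noisy values still map to a valid integer."""
    value = 0
    for entry in signed_bits:
        value = (value << 1) | (1 if entry > 0 else 0)
    return value


# 200 is 0b11001000, so bits 7, 6 and 3 are set.
encoded = int_to_signed_bits(200)
decoded = signed_bits_to_int(encoded)

# A denoised sample is real-valued; thresholding recovers the integer anyway.
noisy = [0.7, 0.2, -0.9, -0.1, 0.4, -0.3, -1.2, -0.5]
decoded_from_noisy = signed_bits_to_int(noisy)
```

Because decoding only looks at the sign of each entry, moderately noisy model outputs still land on a valid discrete value.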
### Stable Diffusion with K Diffusion
Make sure you have @crowsonkb's https://github.com/crowsonkb/k-diffusion installed:
```
pip install k-diffusion
```
You can use the community pipeline as follows:
```python
import torch
from diffusers import DiffusionPipeline
seed = 33  # any fixed integer works; 33 matches the comparison snippets below
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
pipe = pipe.to("cuda")
prompt = "an astronaut riding a horse on mars"
pipe.set_sampler("sample_heun")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
image.save("./astronaut_heun_k_diffusion.png")
```
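Judging from the pipeline source, `set_sampler` resolves the string by importing `k_diffusion` and looking the name up on its `sampling` module, so any function name defined there should be usable. The sketch below shows the same lookup-by-name pattern on a standard-library module so that it runs without `k_diffusion` installed; `resolve_function` is a hypothetical helper, and the added error handling is an assumption, not something the pipeline does.

```python
import importlib


def resolve_function(module_name, function_name):
    """Import `module_name` and return the attribute called `function_name`."""
    module = importlib.import_module(module_name)
    try:
        return getattr(module, function_name)
    except AttributeError as err:
        # Turn the lookup failure into a clearer error for the caller.
        raise ValueError(f"{module_name!r} has no function named {function_name!r}") from err


# `math.sqrt` stands in for something like `k_diffusion.sampling.sample_heun`.
sqrt = resolve_function("math", "sqrt")
result = sqrt(9)
```

A misspelled sampler name would fail at lookup time rather than at generation time, which is why checking the name early can be convenient.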
To make sure that K Diffusion and `diffusers` yield the same results:
**Diffusers**:
```python
import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
seed = 33
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
```
![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler.png)
**K Diffusion**:
```python
import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
seed = 33
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
pipe.set_sampler("sample_euler")
generator = torch.Generator(device="cuda").manual_seed(seed)
image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
```
![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler_k_diffusion.png)

View File

@@ -0,0 +1,265 @@
from typing import Optional, Tuple, Union
import torch
from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.pipeline_utils import ImagePipelineOutput
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
from einops import rearrange, reduce
BITS = 8
# convert to bit representations and back taken from https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py
def decimal_to_bits(x, bits=BITS):
"""expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1"""
device = x.device
x = (x * 255).int().clamp(0, 255)
mask = 2 ** torch.arange(bits - 1, -1, -1, device=device)
mask = rearrange(mask, "d -> d 1 1")
x = rearrange(x, "b c h w -> b c 1 h w")
bits = ((x & mask) != 0).float()
bits = rearrange(bits, "b c d h w -> b (c d) h w")
bits = bits * 2 - 1
return bits
def bits_to_decimal(x, bits=BITS):
"""expects bits from -1 to 1, outputs image tensor from 0 to 1"""
device = x.device
x = (x > 0).int()
mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32)
mask = rearrange(mask, "d -> d 1 1")
x = rearrange(x, "b (c d) h w -> b c d h w", d=bits)
dec = reduce(x * mask, "b c d h w -> b c h w", "sum")
return (dec / 255).clamp(0.0, 1.0)
# modified scheduler step functions for clamping the predicted x_0 between -bit_scale and +bit_scale
def ddim_bit_scheduler_step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = True,
generator=None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
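use_clipped_model_output (`bool`): if `True`, re-derive `model_output` from the clipped predicted x_0 before computing the sample direction.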
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper for an in-depth understanding
# Notation (<variable name> -> <name in paper>)
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
scale = self.bit_scale
if self.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 - α_{t-1}) / (1 - α_t)) * sqrt(1 - α_t / α_{t-1})
variance = self._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
if use_clipped_model_output:
# the model_output is always re-derived from the clipped x_0 in Glide
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else "cpu"
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample,)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def ddpm_bit_scheduler_step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
prediction_type="epsilon",
generator=None,
return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
predicted_variance = None
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(f"Unsupported prediction_type {prediction_type}.")
# 3. Clip "predicted x_0"
scale = self.bit_scale
if self.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -scale, scale)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
variance = 0
if t > 0:
noise = torch.randn(
model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
).to(model_output.device)
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
pred_prev_sample = pred_prev_sample + variance
if not return_dict:
return (pred_prev_sample,)
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
class BitDiffusion(DiffusionPipeline):
def __init__(
self,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, DDPMScheduler],
bit_scale: Optional[float] = 1.0,
):
super().__init__()
self.bit_scale = bit_scale
self.scheduler.step = (
ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step
)
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
height: Optional[int] = 256,
width: Optional[int] = 256,
num_inference_steps: Optional[int] = 50,
generator: Optional[torch.Generator] = None,
batch_size: Optional[int] = 1,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
) -> Union[Tuple, ImagePipelineOutput]:
latents = torch.randn(
(batch_size, self.unet.in_channels, height, width),
generator=generator,
)
latents = decimal_to_bits(latents) * self.bit_scale
latents = latents.to(self.device)
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
# predict the noise residual
noise_pred = self.unet(latents, t).sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
image = bits_to_decimal(latents)
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@@ -78,7 +78,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
)
self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
self.make_cutouts = MakeCutouts(feature_extractor.size)
cut_out_size = (
feature_extractor.size
if isinstance(feature_extractor.size, int)
else feature_extractor.size["shortest_edge"]
)
self.make_cutouts = MakeCutouts(cut_out_size)
set_requires_grad(self.text_encoder, False)
set_requires_grad(self.clip_model, False)

View File

@@ -18,17 +18,38 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
}
else:
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
}
# ------------------------------------------------------------------------------
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)

View File

@@ -110,7 +110,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"

View File

@@ -101,7 +101,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"

View File

@@ -13,9 +13,31 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, is_accelerate_available, logging
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
}
else:
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
}
# ------------------------------------------------------------------------------
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
re_attention = re.compile(
@@ -358,7 +380,7 @@ def get_weighted_text_embeddings(
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
@@ -369,7 +391,7 @@ def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -447,7 +469,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"

View File

@@ -11,9 +11,30 @@ from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
}
else:
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
}
# ------------------------------------------------------------------------------
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
re_attention = re.compile(
@@ -365,7 +386,7 @@ def get_weighted_text_embeddings(
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
return 2.0 * image - 1.0
@@ -375,7 +396,7 @@ def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?

View File

@@ -113,7 +113,7 @@ class MultilingualStableDiffusion(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"

View File

@@ -0,0 +1,479 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Callable, List, Optional, Union
import torch
from diffusers import LMSDiscreteScheduler
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import is_accelerate_available, logging
from k_diffusion.external import CompVisDenoiser
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ModelWrapper:
def __init__(self, model, alphas_cumprod):
self.model = model
self.alphas_cumprod = alphas_cumprod
def apply_model(self, *args, **kwargs):
return self.model(*args, **kwargs).sample
class StableDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
):
super().__init__()
if safety_checker is None:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
# get correct sigmas from LMS
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
model = ModelWrapper(unet, scheduler.alphas_cumprod)
self.k_diffusion_model = CompVisDenoiser(model)
def set_sampler(self, scheduler_type: str):
library = importlib.import_module("k_diffusion")
sampling = getattr(library, "sampling")
self.sampler = getattr(sampling, scheduler_type)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
return latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = True
if guidance_scale <= 1.0:
raise ValueError("`guidance_scale` has to be greater than 1 because this pipeline always uses classifier-free guidance.")
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
sigmas = self.scheduler.sigmas
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
latents = latents * sigmas[0]
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
def model_fn(x, t):
latent_model_input = torch.cat([x] * 2)
noise_pred = self.k_diffusion_model(latent_model_input, t, encoder_hidden_states=text_embeddings)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
return noise_pred
latents = self.sampler(model_fn, latents, sigmas)
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
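
The `model_fn` closure in `__call__` above combines the unconditional and the text-conditioned noise predictions into one guided prediction. A minimal sketch of that arithmetic on plain Python lists follows. `apply_guidance` is a hypothetical helper, not part of the pipeline; the real code operates on `torch` tensors that it splits with `chunk(2)`.

```python
from typing import List


def apply_guidance(
    noise_uncond: List[float],
    noise_text: List[float],
    guidance_scale: float,
) -> List[float]:
    """Hypothetical helper: element-wise classifier-free guidance."""
    # The pipeline above rejects guidance_scale <= 1, so the sketch does too.
    if guidance_scale <= 1.0:
        raise ValueError("`guidance_scale` has to be greater than 1.")
    # Same formula as in `model_fn`: uncond + scale * (text - uncond).
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


# Where both predictions agree, the result is unchanged; where they differ,
# the difference is amplified by the guidance scale.
guided = apply_guidance([0.0, 1.0], [1.0, 1.0], guidance_scale=7.5)
print(guided)
```

With the default `guidance_scale=7.5`, the difference between the two predictions is scaled by 7.5 before it is added back to the unconditional prediction.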


@@ -42,7 +42,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
super().__init__()
if safety_checker is None:
logger.warn(
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"


@@ -0,0 +1,320 @@
from typing import Callable, List, Optional, Union
import torch
import PIL
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import deprecate, is_accelerate_available, logging
from transformers import (
CLIPFeatureExtractor,
CLIPSegForImageSegmentation,
CLIPSegProcessor,
CLIPTextModel,
CLIPTokenizer,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class TextInpainting(DiffusionPipeline):
r"""
Pipeline for text based inpainting using Stable Diffusion.
Uses CLIPSeg to get a mask from the given text, then calls the Inpainting pipeline with the generated mask.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
segmentation_model ([`CLIPSegForImageSegmentation`]):
CLIPSeg Model to generate mask from the given text. Please refer to the [model card]() for details.
segmentation_processor ([`CLIPSegProcessor`]):
CLIPSeg processor to prepare the image and text features for the segmentation model. Please refer to the
[model card](https://huggingface.co/docs/transformers/model_doc/clipseg) for details.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
segmentation_model: CLIPSegForImageSegmentation,
segmentation_processor: CLIPSegProcessor,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
" sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
" Hub, it would be very nice if you could open a Pull request for the"
" `scheduler/scheduler_config.json` file"
)
deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["skip_prk_steps"] = True
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules(
segmentation_model=segmentation_model,
segmentation_processor=segmentation_processor,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
text: str,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
be masked out with `mask_image` and repainted according to `prompt`.
text (`str`):
The text to use to generate the mask.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# We use the input text to generate the mask
inputs = self.segmentation_processor(
text=[text], images=[image], padding="max_length", return_tensors="pt"
).to(self.device)
outputs = self.segmentation_model(**inputs)
mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy()
mask_pil = self.numpy_to_pil(mask)[0].resize(image.size)
# Run inpainting pipeline with the generated mask
inpainting_pipeline = StableDiffusionInpaintPipeline(
vae=self.vae,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
unet=self.unet,
scheduler=self.scheduler,
safety_checker=self.safety_checker,
feature_extractor=self.feature_extractor,
)
return inpainting_pipeline(
prompt=prompt,
image=image,
mask_image=mask_pil,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
)
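
The `__call__` above turns the CLIPSeg logits into a soft mask with `torch.sigmoid` before resizing it to the input image and handing it to the inpainting pipeline. A minimal sketch of that logits-to-mask step in pure Python follows. `logits_to_mask` is a hypothetical helper, and the nested-list input stands in for the tensor in `outputs.logits`.

```python
import math
from typing import List


def logits_to_mask(logits: List[List[float]]) -> List[List[float]]:
    """Hypothetical helper: map raw segmentation logits to values in (0, 1)."""
    # Element-wise sigmoid, the same squashing the pipeline applies.
    return [[1.0 / (1.0 + math.exp(-value)) for value in row] for row in logits]


# A large positive logit means "this pixel matches the text", a large negative
# logit means it does not, and 0 sits exactly in the middle.
mask = logits_to_mask([[0.0, 10.0], [-10.0, 0.0]])
for row in mask:
    print([round(value, 3) for value in row])
```

Pixels with values close to 1 are the ones the inpainting pipeline repaints according to `prompt`; values close to 0 are left as they are.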


@@ -135,7 +135,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"


@@ -39,6 +39,8 @@ Now let's get our dataset. Download images from [here](https://drive.google.com/
And launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
@@ -92,7 +94,7 @@ accelerate launch train_dreambooth.py \
With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes, it's possible to train dreambooth on a 16GB GPU.
Install `bitsandbytes` with `pip install bitsandbytes`
To install `bitsandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
@@ -141,7 +143,7 @@ export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_dreambooth.py \
accelerate launch --mixed_precision="fp16" train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
@@ -157,8 +159,7 @@ accelerate launch train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800 \
--mixed_precision=fp16
--max_train_steps=800
```
### Fine-tune text encoder with the UNet.

View File

@@ -124,6 +124,7 @@ def parse_args(input_args=None):
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument("--save_steps", type=int, default=500, help="Save a checkpoint every X update steps.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
@@ -187,12 +188,12 @@ def parse_args(input_args=None):
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10 and an Nvidia Ampere GPU. Defaults to the value from the accelerate config of the current system or the"
" flag passed with the `accelerate launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -472,7 +473,7 @@ def main(args):
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
@@ -538,9 +539,9 @@ def main(args):
)
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encoder and vae to gpu.
@@ -603,23 +604,31 @@ def main(args):
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
if args.with_prior_preservation:
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -638,6 +647,17 @@ def main(args):
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
revision=args.revision,
)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
pipeline.save_pretrained(save_path)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
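
The target selection added in the hunks above can be sketched without PyTorch. This is a minimal illustration, not the script's code: `select_target` and `velocity_target` are hypothetical names, tensors are replaced by flat lists of floats, and `alpha_bar_t` stands in for the scheduler's cumulative alpha at the sampled timestep. The velocity formula assumed here is `v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample`.

```python
import math


def velocity_target(sample, noise, alpha_bar_t):
    # Assumed v-prediction target: sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * sample
    a = math.sqrt(alpha_bar_t)
    s = math.sqrt(1.0 - alpha_bar_t)
    return [a * n - s * x for x, n in zip(sample, noise)]


def select_target(prediction_type, sample, noise, alpha_bar_t):
    # Mirrors the branching in the diff: the loss target depends on what
    # the scheduler config says the model is trained to predict.
    if prediction_type == "epsilon":
        return list(noise)
    if prediction_type == "v_prediction":
        return velocity_target(sample, noise, alpha_bar_t)
    raise ValueError(f"Unknown prediction type {prediction_type}")
```

With `alpha_bar_t` close to 1 (almost no noise added) the velocity target approaches the noise itself, and with `alpha_bar_t` close to 0 it approaches the negated clean sample, which is why the two prediction types need different targets in the MSE loss.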

View File

@@ -42,11 +42,13 @@ If you have already cloned the repo, then you won't need to go through these ste
#### Hardware
With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory.
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"
accelerate launch train_text_to_image.py \
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
@@ -54,7 +56,6 @@ accelerate launch train_text_to_image.py \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
@@ -70,7 +71,7 @@ If you wish to use custom loading logic, you should modify the script, we have l
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
accelerate launch train_text_to_image.py \
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--use_ema \
@@ -78,7 +79,6 @@ accelerate launch train_text_to_image.py \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \

View File

@@ -15,13 +15,12 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTokenizer
logger = get_logger(__name__)
@@ -36,6 +35,13 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -186,12 +192,12 @@ def parse_args():
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10 and an Nvidia Ampere GPU. Defaults to the value from the accelerate config of the current system or the"
" flag passed with the `accelerate launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
@@ -335,10 +341,24 @@ def main():
os.makedirs(args.output_dir, exist_ok=True)
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
# Freeze vae and text_encoder
vae.requires_grad_(False)
@@ -372,7 +392,7 @@ def main():
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -496,9 +516,9 @@ def main():
)
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encoder and vae to gpu.
@@ -562,9 +582,17 @@ def main():
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
@@ -600,14 +628,12 @@ def main():
if args.use_ema:
ema_unet.copy_to(unet.parameters())
pipeline = StableDiffusionPipeline(
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
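
Changing the `--mixed_precision` default from `"no"` to `None` lets the script distinguish "flag not passed" from "explicitly disabled". The precedence described in the new help text can be sketched as follows. This is only an illustration of the intended behavior: `resolve_mixed_precision` and `launcher_value` are hypothetical names, and the real fallback happens inside Accelerate rather than in a helper like this.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # default=None means "not specified on the script's command line"
    parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"])
    return parser


def resolve_mixed_precision(cli_value, launcher_value="no"):
    # Hypothetical helper: an explicit script flag wins; otherwise fall back
    # to whatever the launcher or saved config selected.
    return cli_value if cli_value is not None else launcher_value


args = build_parser().parse_args([])
effective = resolve_mixed_precision(args.mixed_precision, launcher_value="fp16")
```

This is also why the scripts now read `accelerator.mixed_precision` instead of `args.mixed_precision` when choosing `weight_dtype`: the argument may be `None` even though mixed precision is active.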

View File

@@ -47,6 +47,8 @@ Now let's get our dataset.Download 3-4 images from [here](https://drive.google.c
And launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="path-to-dir-containing-images"

View File

@@ -16,24 +16,45 @@ import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
}
else:
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
}
# ------------------------------------------------------------------------------
logger = get_logger(__name__)
def save_progress(text_encoder, placeholder_token_id, accelerator, args):
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
torch.save(learned_embeds_dict, save_path)
def parse_args():
@@ -51,6 +72,13 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -260,10 +288,10 @@ class TextualInversionDataset(Dataset):
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"linear": PIL_INTERPOLATION["linear"],
"bilinear": PIL_INTERPOLATION["bilinear"],
"bicubic": PIL_INTERPOLATION["bicubic"],
"lanczos": PIL_INTERPOLATION["lanczos"],
}[interpolation]
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
@@ -383,9 +411,21 @@ def main():
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
@@ -419,7 +459,7 @@ def main():
eps=args.adam_epsilon,
)
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
train_dataset = TextualInversionDataset(
data_root=args.train_data_dir,
@@ -510,9 +550,17 @@ def main():
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
# Zero out the gradients for all token embeddings except the newly added
@@ -534,7 +582,8 @@ def main():
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
save_progress(text_encoder, placeholder_token_id, accelerator, args)
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -547,18 +596,18 @@ def main():
# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline(
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
save_progress(text_encoder, placeholder_token_id, accelerator, args)
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
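
Both the DreamBooth and textual inversion scripts in this commit now write step-stamped artifacts instead of overwriting a single file. A minimal sketch of that naming scheme is shown here; `step_save_path` is a hypothetical helper and the directory name is a placeholder.

```python
import os


def step_save_path(output_dir, global_step, save_steps, template):
    # Return a step-stamped path when this step is a save step, else None.
    # The scripts increment global_step before checking, so step 0 never occurs there.
    if save_steps > 0 and global_step > 0 and global_step % save_steps == 0:
        return os.path.join(output_dir, template.format(step=global_step))
    return None


# Naming patterns taken from the diff above
CHECKPOINT_TEMPLATE = "checkpoint-{step}"
EMBEDDINGS_TEMPLATE = "learned_embeds-steps-{step}.bin"
```

Keeping the final `learned_embeds.bin` name for the end-of-training save means existing loading code keeps working while intermediate saves remain available for comparison.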

View File

@@ -28,12 +28,33 @@ from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import HfFolder, Repository, whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
PIL_INTERPOLATION = {
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
}
else:
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
}
# ------------------------------------------------------------------------------
logger = logging.getLogger(__name__)
@@ -246,10 +267,10 @@ class TextualInversionDataset(Dataset):
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"linear": PIL_INTERPOLATION["linear"],
"bilinear": PIL_INTERPOLATION["bilinear"],
"bicubic": PIL_INTERPOLATION["bicubic"],
"lanczos": PIL_INTERPOLATION["lanczos"],
}[interpolation]
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
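
The `PIL_INTERPOLATION` table above branches on the Pillow version string. An alternative, shown here only as a sketch and not what the commit does, is feature detection: use the `Resampling` namespace when it exists and fall back to the module-level constants otherwise. `build_interpolation_table` is a hypothetical helper that takes the image module as a parameter, so it can be exercised with stand-in objects instead of a real Pillow install.

```python
from types import SimpleNamespace


def build_interpolation_table(image_module):
    # Newer Pillow exposes the constants under Image.Resampling; older
    # releases only have them at module level.
    resampling = getattr(image_module, "Resampling", image_module)
    return {
        "linear": resampling.BILINEAR,
        "bilinear": resampling.BILINEAR,
        "bicubic": resampling.BICUBIC,
        "lanczos": resampling.LANCZOS,
        "nearest": resampling.NEAREST,
    }


# Stand-ins for the two module layouts (the integer values are placeholders)
_constants = dict(NEAREST=0, LANCZOS=1, BILINEAR=2, BICUBIC=3)
old_style = SimpleNamespace(**_constants)
new_style = SimpleNamespace(Resampling=SimpleNamespace(**_constants))
```

In real code the call would be `build_interpolation_table(PIL.Image)`. The version comparison used in the diff is more explicit about which release introduced the change, while feature detection avoids parsing version strings.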

View File

@@ -127,3 +127,24 @@ dataset.push_to_hub("name_of_your_dataset", private=True)
and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
#### Use ONNXRuntime to accelerate training
To use ONNX Runtime to accelerate training, use `train_unconditional_ort.py`.
The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxruntime:
```bash
accelerate launch train_unconditional_ort.py \
--dataset_name="huggan/flowers-102-categories" \
--resolution=64 \
--output_dir="ddpm-ema-flowers-64" \
--train_batch_size=16 \
--num_epochs=1 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-4 \
--lr_warmup_steps=500 \
--mixed_precision=fp16
```
Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.

View File

@@ -194,9 +194,10 @@ def parse_args():
)
parser.add_argument(
"--predict_epsilon",
action="store_true",
default=True,
"--prediction_type",
type=str,
default="epsilon",
choices=["epsilon", "sample"],
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
)
@@ -256,13 +257,13 @@ def main(args):
"UpBlock2D",
),
)
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_predict_epsilon:
if accepts_prediction_type:
noise_scheduler = DDPMScheduler(
num_train_timesteps=args.ddpm_num_steps,
beta_schedule=args.ddpm_beta_schedule,
predict_epsilon=args.predict_epsilon,
prediction_type=args.prediction_type,
)
else:
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
@@ -319,7 +320,12 @@ def main(args):
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
ema_model = EMAModel(
accelerator.unwrap_model(model),
inv_gamma=args.ema_inv_gamma,
power=args.ema_power,
max_value=args.ema_max_decay,
)
# Handle the repository creation
if accelerator.is_main_process:
@@ -365,9 +371,9 @@ def main(args):
# Predict the noise residual
model_output = model(noisy_images, timesteps).sample
if args.predict_epsilon:
if args.prediction_type == "epsilon":
loss = F.mse_loss(model_output, noise) # this could have different weights!
else:
elif args.prediction_type == "sample":
alpha_t = _extract_into_tensor(
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
)
@@ -376,6 +382,8 @@ def main(args):
model_output, clean_images, reduction="none"
) # use SNR weighting from distillation paper
loss = loss.mean()
else:
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
accelerator.backward(loss)
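
The `accepts_prediction_type` check above uses `inspect.signature` so the script keeps working against scheduler versions that do not yet take the new keyword. The pattern can be sketched in isolation as follows; `OldScheduler` and `NewScheduler` are hypothetical stand-ins, not diffusers classes.

```python
import inspect


def accepts_kwarg(cls, name):
    # True if the class constructor declares a parameter with this name
    return name in inspect.signature(cls.__init__).parameters


class OldScheduler:  # hypothetical: only knows the boolean flag
    def __init__(self, num_train_timesteps=1000, predict_epsilon=True):
        self.num_train_timesteps = num_train_timesteps
        self.predict_epsilon = predict_epsilon


class NewScheduler:  # hypothetical: takes the string-valued option
    def __init__(self, num_train_timesteps=1000, prediction_type="epsilon"):
        self.num_train_timesteps = num_train_timesteps
        self.prediction_type = prediction_type


def make_scheduler(cls, num_train_timesteps, prediction_type):
    kwargs = {"num_train_timesteps": num_train_timesteps}
    # Only forward the option when the constructor can accept it
    if accepts_kwarg(cls, "prediction_type"):
        kwargs["prediction_type"] = prediction_type
    return cls(**kwargs)
```

Note that the fallback silently drops the requested prediction type on older constructors, which matches the `else` branch in the diff but may deserve a warning in production code.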

View File

@@ -0,0 +1,251 @@
import argparse
import math
import os
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from onnxruntime.training.ortmodule import ORTModule
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm
logger = get_logger(__name__)
def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
logging_dir=logging_dir,
)
model = UNet2DModel(
sample_size=args.resolution,
in_channels=3,
out_channels=3,
layers_per_block=2,
block_out_channels=(128, 128, 256, 256, 512, 512),
down_block_types=(
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D",
"DownBlock2D",
),
up_block_types=(
"UpBlock2D",
"AttnUpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
),
)
model = ORTModule(model)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
augmentations = Compose(
[
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
CenterCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Normalize([0.5], [0.5]),
]
)
if args.dataset_name is not None:
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
use_auth_token=True if args.use_auth_token else None,
split="train",
)
else:
dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
def transforms(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps,
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
if accelerator.is_main_process:
run = os.path.split(__file__)[-1].split(".")[0]
accelerator.init_trackers(run)
global_step = 0
for epoch in range(args.num_epochs):
model.train()
progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
progress_bar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"]
# Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
bsz = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
).long()
# Add noise to the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
with accelerator.accumulate(model):
# Predict the noise residual
noise_pred = model(noisy_images, timesteps, return_dict=True)[0]
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
if args.use_ema:
ema_model.step(model)
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
if args.use_ema:
logs["ema_decay"] = ema_model.decay
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
progress_bar.close()
accelerator.wait_for_everyone()
# Generate sample images for visual inspection
if accelerator.is_main_process:
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
pipeline = DDPMPipeline(
unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
scheduler=noise_scheduler,
)
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
accelerator.trackers[0].writer.add_images(
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
)
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone()
accelerator.end_training()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset_name", type=str, default=None)
parser.add_argument("--dataset_config_name", type=str, default=None)
parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
parser.add_argument("--overwrite_output_dir", action="store_true")
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--save_images_epochs", type=int, default=10)
parser.add_argument("--save_model_epochs", type=int, default=10)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--lr_scheduler", type=str, default="cosine")
parser.add_argument("--lr_warmup_steps", type=int, default=500)
parser.add_argument("--adam_beta1", type=float, default=0.95)
parser.add_argument("--adam_beta2", type=float, default=0.999)
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
parser.add_argument("--adam_epsilon", type=float, default=1e-08)
parser.add_argument("--use_ema", action="store_true", default=True)
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
parser.add_argument("--ema_power", type=float, default=3 / 4)
parser.add_argument("--ema_max_decay", type=float, default=0.9999)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--use_auth_token", action="store_true")
parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None)
parser.add_argument("--hub_private_repo", action="store_true")
parser.add_argument("--logging_dir", type=str, default="logs")
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
main(args)

View File

@@ -211,6 +211,7 @@ def create_unet_diffusers_config(original_config):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
model_params = original_config.model.params
unet_params = original_config.model.params.unet_config.params
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
@@ -230,7 +231,7 @@ def create_unet_diffusers_config(original_config):
resolution //= 2
config = dict(
sample_size=unet_params.image_size,
sample_size=model_params.image_size,
in_channels=unet_params.in_channels,
out_channels=unet_params.out_channels,
down_block_types=tuple(down_block_types),

View File

@@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
)
del pipeline.safety_checker
safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
feature_extractor = pipeline.feature_extractor
else:
safety_checker = None
feature_extractor = None
onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
safety_checker=safety_checker,
feature_extractor=pipeline.feature_extractor,
feature_extractor=feature_extractor,
requires_safety_checker=safety_checker is not None,
)
onnx_pipeline.save_pretrained(output_path)

View File

@@ -0,0 +1,791 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the Versatile Stable Diffusion checkpoints. """
import argparse
from argparse import Namespace
import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
VersatileDiffusionPipeline,
)
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel
from transformers import (
CLIPFeatureExtractor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
SCHEDULER_CONFIG = Namespace(
**{
"beta_linear_start": 0.00085,
"beta_linear_end": 0.012,
"timesteps": 1000,
"scale_factor": 0.18215,
}
)
IMAGE_UNET_CONFIG = Namespace(
**{
"input_channels": 4,
"model_channels": 320,
"output_channels": 4,
"num_noattn_blocks": [2, 2, 2, 2],
"channel_mult": [1, 2, 4, 4],
"with_attn": [True, True, True, False],
"num_heads": 8,
"context_dim": 768,
"use_checkpoint": True,
}
)
TEXT_UNET_CONFIG = Namespace(
**{
"input_channels": 768,
"model_channels": 320,
"output_channels": 768,
"num_noattn_blocks": [2, 2, 2, 2],
"channel_mult": [1, 2, 4, 4],
"second_dim": [4, 4, 4, 4],
"with_attn": [True, True, True, False],
"num_heads": 8,
"context_dim": 768,
"use_checkpoint": True,
}
)
AUTOENCODER_CONFIG = Namespace(
**{
"double_z": True,
"z_channels": 4,
"resolution": 256,
"in_channels": 3,
"out_ch": 3,
"ch": 128,
"ch_mult": [1, 2, 4, 4],
"num_res_blocks": 2,
"attn_resolutions": [],
"dropout": 0.0,
}
)
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
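# Illustrative sketch (not part of the commit): how `shave_segments` treats positive, negative and zero
# counts. The helper is copied from above and re-indented so the snippet runs on its own; the example key
# is an assumed, typical checkpoint key.

```python
def shave_segments(path, n_shave_prefix_segments=1):
    """Copy of the helper above, re-indented so this snippet is self-contained."""
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


# Assumed example key, shaped like the UNet keys this script handles.
key = "input_blocks.1.0.in_layers.0.weight"

front = shave_segments(key, 2)   # drops the two leading segments
back = shave_segments(key, -1)   # drops the last segment
same = shave_segments(key, 0)    # zero leaves the key untouched

print(front)
print(back)
print(same)
```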
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
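# Illustrative sketch (not part of the commit): the two renaming stages applied to a resnet key. The local
# renames mirror `renew_resnet_paths`; the final `replace` mirrors the `additional_replacements` step in
# `assign_to_checkpoint`. `RESNET_RENAMES` and `rename_resnet_key` are hypothetical names used only here.

```python
# Local renames, in the same order as renew_resnet_paths above.
RESNET_RENAMES = [
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]


def rename_resnet_key(old_key, meta_old, meta_new):
    """Apply the local renames, then the block-level (meta path) rename."""
    new_key = old_key
    for old, new in RESNET_RENAMES:
        new_key = new_key.replace(old, new)
    return new_key.replace(meta_old, meta_new)


# Assumed example: first resnet of the first down block.
norm_key = rename_resnet_key(
    "input_blocks.1.0.in_layers.0.weight", "input_blocks.1.0", "down_blocks.0.resnets.0"
)
emb_key = rename_resnet_key(
    "input_blocks.1.0.emb_layers.1.bias", "input_blocks.1.0", "down_blocks.0.resnets.0"
)
print(norm_key)
print(emb_key)
```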
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
to them. It splits attention layers, and takes into account additional replacements
that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
elif path["old"] in old_checkpoint:
checkpoint[new_path] = old_checkpoint[path["old"]]
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def create_image_unet_diffusers_config(unet_params):
"""
Creates a config for the diffusers based on the config of the VD model.
"""
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if unet_params.with_attn[i] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if unet_params.with_attn[-i - 1] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks):
raise ValueError("Not all num_noattn_blocks are equal, which is not supported in this script.")
config = dict(
sample_size=None,
in_channels=unet_params.input_channels,
out_channels=unet_params.output_channels,
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
layers_per_block=unet_params.num_noattn_blocks[0],
cross_attention_dim=unet_params.context_dim,
attention_head_dim=unet_params.num_heads,
)
return config
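# Illustrative sketch (not part of the commit): the channel widths and block types that the function above
# derives from the values in IMAGE_UNET_CONFIG. Only the three fields needed for the derivation are
# reproduced here.

```python
from argparse import Namespace

# Subset of IMAGE_UNET_CONFIG from this script.
unet_params = Namespace(
    model_channels=320,
    channel_mult=[1, 2, 4, 4],
    with_attn=[True, True, True, False],
)

# One output width per resolution level.
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

# Down blocks follow with_attn in order; up blocks read it in reverse.
down_block_types = [
    "CrossAttnDownBlock2D" if unet_params.with_attn[i] else "DownBlock2D"
    for i in range(len(block_out_channels))
]
up_block_types = [
    "CrossAttnUpBlock2D" if unet_params.with_attn[-i - 1] else "UpBlock2D"
    for i in range(len(block_out_channels))
]

print(block_out_channels)
print(down_block_types)
print(up_block_types)
```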
def create_text_unet_diffusers_config(unet_params):
"""
Creates a config for the diffusers based on the config of the VD model.
"""
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlockFlat" if unet_params.with_attn[i] else "DownBlockFlat"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlockFlat" if unet_params.with_attn[-i - 1] else "UpBlockFlat"
up_block_types.append(block_type)
resolution //= 2
if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks):
raise ValueError("Not all num_noattn_blocks are equal, which is not supported in this script.")
config = dict(
sample_size=None,
in_channels=(unet_params.input_channels, 1, 1),
out_channels=(unet_params.output_channels, 1, 1),
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
layers_per_block=unet_params.num_noattn_blocks[0],
cross_attention_dim=unet_params.context_dim,
attention_head_dim=unet_params.num_heads,
)
return config
def create_vae_diffusers_config(vae_params):
"""
Creates a config for the diffusers based on the config of the VD model.
"""
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = dict(
sample_size=vae_params.resolution,
in_channels=vae_params.in_channels,
out_channels=vae_params.out_ch,
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
latent_channels=vae_params.z_channels,
layers_per_block=vae_params.num_res_blocks,
)
return config
def create_diffusers_scheduler(original_config):
scheduler = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps,
beta_start=original_config.model.params.linear_start,
beta_end=original_config.model.params.linear_end,
beta_schedule="scaled_linear",
)
return scheduler
def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
# more than 100 parameters have to start with `model_ema` for the checkpoint to count as containing EMA weights
if sum(k.startswith("model_ema") for k in keys) > 100:
print("Checkpoint has both EMA and non-EMA weights.")
if extract_ema:
print(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
print(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["model.diffusion_model.time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["model.diffusion_model.time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["model.diffusion_model.time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["model.diffusion_model.time_embed.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
elif f"input_blocks.{i}.0.weight" in unet_state_dict:
# text_unet uses linear layers in place of downsamplers
shape = unet_state_dict[f"input_blocks.{i}.0.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.bias"
)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_paths = renew_resnet_paths(resnet_0)
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
resnet_1_paths = renew_resnet_paths(resnet_1)
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if ["conv.weight", "conv.bias"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
elif f"output_blocks.{i}.1.weight" in unet_state_dict:
# text_unet uses linear layers in place of upsamplers
shape = unet_state_dict[f"output_blocks.{i}.1.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop(
f"output_blocks.{i}.1.weight"
)
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop(
f"output_blocks.{i}.1.bias"
)
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
elif f"output_blocks.{i}.2.weight" in unet_state_dict:
# text_unet uses linear layers in place of upsamplers
shape = unet_state_dict[f"output_blocks.{i}.2.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop(
f"output_blocks.{i}.2.weight"
)
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop(
f"output_blocks.{i}.2.bias"
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
new_checkpoint[new_path] = unet_state_dict[old_path]
return new_checkpoint
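# Illustrative sketch (not part of the commit): how the flat `input_blocks.{i}` index used above maps to a
# (down block, layer within block) pair. `locate_input_block` is a hypothetical helper; the arithmetic is
# the one in the loop over input blocks, and layers_per_block=2 matches IMAGE_UNET_CONFIG. The range of 12
# entries assumes four levels with two resnets each and a downsampler after the first three.

```python
def locate_input_block(i, layers_per_block):
    """Map a flat input block index (i >= 1) to (block_id, layer_in_block_id)."""
    block_id = (i - 1) // (layers_per_block + 1)
    layer_in_block_id = (i - 1) % (layers_per_block + 1)
    return block_id, layer_in_block_id


layers_per_block = 2

# Index 0 is conv_in and is handled separately, so the mapping starts at 1.
layout = {i: locate_input_block(i, layers_per_block) for i in range(1, 12)}

for i, (block_id, layer_id) in layout.items():
    # The slot equal to layers_per_block is where a downsampler sits, when the block has one.
    kind = "downsampler slot" if layer_id == layers_per_block else "resnet slot"
    print(f"input_blocks.{i} -> down_blocks.{block_id}, position {layer_id} ({kind})")
```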
def convert_vd_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
keys = list(checkpoint.keys())
for key in keys:
vae_state_dict[key] = checkpoint.get(key)
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--unet_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--vae_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--optimus_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--scheduler_type",
default="pndm",
type=str,
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
)
parser.add_argument(
"--extract_ema",
action="store_true",
help=(
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
),
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
scheduler_config = SCHEDULER_CONFIG
num_train_timesteps = scheduler_config.timesteps
beta_start = scheduler_config.beta_linear_start
beta_end = scheduler_config.beta_linear_end
if args.scheduler_type == "pndm":
scheduler = PNDMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
skip_prk_steps=True,
steps_offset=1,
)
elif args.scheduler_type == "lms":
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
elif args.scheduler_type == "euler":
scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
elif args.scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
)
elif args.scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
)
elif args.scheduler_type == "ddim":
scheduler = DDIMScheduler(
beta_start=beta_start,
beta_end=beta_end,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
else:
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
# Convert the UNet2DConditionModel models.
if args.unet_checkpoint_path is not None:
# image UNet
image_unet_config = create_image_unet_diffusers_config(IMAGE_UNET_CONFIG)
checkpoint = torch.load(args.unet_checkpoint_path)
converted_image_unet_checkpoint = convert_vd_unet_checkpoint(
checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
)
image_unet = UNet2DConditionModel(**image_unet_config)
image_unet.load_state_dict(converted_image_unet_checkpoint)
# text UNet
text_unet_config = create_text_unet_diffusers_config(TEXT_UNET_CONFIG)
converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
)
text_unet = UNetFlatConditionModel(**text_unet_config)
text_unet.load_state_dict(converted_text_unet_checkpoint)
# Convert the VAE model.
if args.vae_checkpoint_path is not None:
vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG)
checkpoint = torch.load(args.vae_checkpoint_path)
converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
image_feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
pipe = VersatileDiffusionPipeline(
scheduler=scheduler,
tokenizer=tokenizer,
image_feature_extractor=image_feature_extractor,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_unet=image_unet,
text_unet=text_unet,
vae=vae,
)
pipe.save_pretrained(args.dump_path)

View File

@@ -39,8 +39,8 @@ import torch
import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.models.attention import Transformer2DModel
from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader
@@ -826,6 +826,20 @@ if __name__ == "__main__":
transformer_model, checkpoint
)
# classifier free sampling embeddings interlude
# The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
# model, so we pull them off the checkpoint before the checkpoint is deleted.
learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf
if learnable_classifier_free_sampling_embeddings:
learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
else:
learned_classifier_free_sampling_embeddings_embeddings = None
# done classifier free sampling embeddings interlude
with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
del diffusers_transformer_checkpoint
@@ -871,6 +885,31 @@ if __name__ == "__main__":
# done scheduler
# learned classifier free sampling embeddings
with init_empty_weights():
learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
learnable_classifier_free_sampling_embeddings,
hidden_size=text_encoder_model.config.hidden_size,
length=tokenizer_model.model_max_length,
)
learned_classifier_free_sampling_checkpoint = {
"embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
}
with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
del learned_classifier_free_sampling_checkpoint
del learned_classifier_free_sampling_embeddings_embeddings
load_checkpoint_and_dispatch(
learned_classifier_free_sampling_embeddings_model,
learned_classifier_free_sampling_checkpoint_file.name,
device_map="auto",
)
# done learned classifier free sampling embeddings
print(f"saving VQ diffusion model, path: {args.dump_path}")
pipe = VQDiffusionPipeline(
@@ -878,6 +917,7 @@ if __name__ == "__main__":
transformer=transformer_model,
tokenizer=tokenizer_model,
text_encoder=text_encoder_model,
learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
scheduler=scheduler_model,
)
pipe.save_pretrained(args.dump_path)

View File

@@ -78,7 +78,7 @@ from setuptools import find_packages, setup
# 1. all dependencies should be listed here with their version requirements if any
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [
"Pillow<10.0", # keep the PIL.Image.Resampling deprecation away
"Pillow", # keep the PIL.Image.Resampling deprecation away
"accelerate>=0.11.0",
"black==22.8",
"datasets",
@@ -97,6 +97,8 @@ _deps = [
"pytest",
"pytest-timeout",
"pytest-xdist",
"safetensors",
"sentencepiece>=0.1.91,!=0.1.92",
"scipy",
"regex!=2019.12.17",
"requests",
@@ -183,9 +185,11 @@ extras["test"] = deps_list(
"pytest",
"pytest-timeout",
"pytest-xdist",
"safetensors",
"sentencepiece",
"scipy",
"torchvision",
"transformers"
"transformers",
)
extras["torch"] = deps_list("torch", "accelerate")
@@ -210,7 +214,7 @@ install_requires = [
setup(
name="diffusers",
version="0.8.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.9.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -9,7 +9,7 @@ from .utils import (
)
__version__ = "0.8.0.dev0"
__version__ = "0.9.0"
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
@@ -46,6 +46,7 @@ if is_torch_available():
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
PNDMScheduler,
@@ -65,12 +66,21 @@ else:
if is_torch_available() and is_transformers_available():
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
CycleDiffusionPipeline,
LDMTextToImagePipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
StableDiffusionUpscalePipeline,
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
else:
@@ -80,6 +90,7 @@ if is_torch_available() and is_transformers_available() and is_onnx_available():
from .pipelines import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionInpaintPipelineLegacy,
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)

View File

@@ -29,7 +29,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
from requests import HTTPError
from . import __version__
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
logger = logging.get_logger(__name__)
@@ -37,6 +37,38 @@ logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json")
class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
self.__frozen = True
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __setattr__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setattr__(name, value)
def __setitem__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
super().__setitem__(name, value)
class ConfigMixin:
r"""
Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles all
@@ -48,21 +80,21 @@ class ConfigMixin:
- **config_name** (`str`) -- A filename under which the config should be stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by parent class).
- **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
`from_config` can be used from a class different than the one used to save the config (should be overridden
by parent class).
overridden by subclass).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
subclass).
"""
config_name = None
ignore_for_config = []
_compatible_classes = []
has_compatibles = False
_deprecated_kwargs = []
def register_to_config(self, **kwargs):
if self.config_name is None:
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
kwargs["_class_name"] = self.__class__.__name__
kwargs["_diffusers_version"] = __version__
# Special case for `kwargs` used in deprecation warning added to schedulers
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
# or solve in a more general way.
@@ -104,9 +136,103 @@ class ConfigMixin:
logger.info(f"Configuration saved in {output_config_file}")
@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
r"""
Instantiate a Python class from a pre-defined JSON-file.
Instantiate a Python class from a config dictionary
Parameters:
config (`Dict[str, Any]`):
A config dictionary from which the Python class will be instantiated. Make sure to only load
configuration files of compatible classes.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether kwargs that are not consumed by the Python class should be returned or not.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the Python class.
`**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
overwrite same named arguments of `config`.
Examples:
```python
>>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
>>> # Download scheduler from huggingface.co and cache.
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
>>> # Instantiate DDIM scheduler class with same config as DDPM
>>> scheduler = DDIMScheduler.from_config(scheduler.config)
>>> # Instantiate PNDM scheduler class with same config as DDPM
>>> scheduler = PNDMScheduler.from_config(scheduler.config)
```
"""
# <===== TO BE REMOVED WITH DEPRECATION
# TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
if "pretrained_model_name_or_path" in kwargs:
config = kwargs.pop("pretrained_model_name_or_path")
if config is None:
raise ValueError("Please make sure to provide a config as the first positional argument.")
# ======>
if not isinstance(config, dict):
deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`. "
if "Scheduler" in cls.__name__:
deprecation_message += (
f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
" Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
" be removed in v1.0.0."
)
elif "Model" in cls.__name__:
deprecation_message += (
f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
" instead. This functionality will be removed in v1.0.0."
)
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
# Allow dtype to be specified on initialization
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")
# add possible deprecated kwargs
for deprecated_kwarg in cls._deprecated_kwargs:
if deprecated_kwarg in unused_kwargs:
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict)
# make sure to also save config parameters that might be used for compatible classes
model.register_to_config(**hidden_dict)
# add hidden kwargs of compatible classes to unused_kwargs
unused_kwargs = {**unused_kwargs, **hidden_dict}
if return_unused_kwargs:
return (model, unused_kwargs)
else:
return model
@classmethod
def get_config_dict(cls, *args, **kwargs):
deprecation_message = (
f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
" removed in version v1.0.0"
)
deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
return cls.load_config(*args, **kwargs)
@classmethod
def load_config(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
r"""
Load a configuration dictionary from a pretrained model name or path.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
@@ -120,10 +246,6 @@ class ConfigMixin:
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -161,33 +283,7 @@ class ConfigMixin:
use this method in a firewalled environment.
</Tip>
"""
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
# Allow dtype to be specified on initialization
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")
# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict)
return_tuple = (model,)
# Flax schedulers have a state, so return it.
if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
state = model.create_state()
return_tuple += (state,)
if return_unused_kwargs:
return return_tuple + (unused_kwargs,)
else:
return return_tuple if len(return_tuple) > 1 else model
@classmethod
def get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
@@ -283,6 +379,9 @@ class ConfigMixin:
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
if return_unused_kwargs:
return config_dict, kwargs
return config_dict
@staticmethod
@@ -291,6 +390,9 @@ class ConfigMixin:
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
# 0. Copy origin config dict
original_dict = {k: v for k, v in config_dict.items()}
# 1. Retrieve expected config attributes from __init__ signature
expected_keys = cls._get_init_keys(cls)
expected_keys.remove("self")
@@ -310,10 +412,11 @@ class ConfigMixin:
# load diffusers library to import compatible and original scheduler
diffusers_library = importlib.import_module(__name__.split(".")[0])
# remove attributes from compatible classes that orig cannot expect
compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
# filter out None potentially undefined dummy classes
compatible_classes = [c for c in compatible_classes if c is not None]
if cls.has_compatibles:
compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
else:
compatible_classes = []
expected_keys_comp_cls = set()
for c in compatible_classes:
expected_keys_c = cls._get_init_keys(c)
@@ -364,7 +467,10 @@ class ConfigMixin:
# 6. Define unused keyword arguments
unused_kwargs = {**config_dict, **kwargs}
return init_dict, unused_kwargs
# 7. Define "hidden" config parameters that were saved for compatible classes
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
return init_dict, unused_kwargs, hidden_config_dict
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
@@ -377,6 +483,12 @@ class ConfigMixin:
@property
def config(self) -> Dict[str, Any]:
"""
Returns the config of the class as a frozen dictionary
Returns:
`Dict[str, Any]`: Config of the class.
"""
return self._internal_dict
def to_json_string(self) -> str:
@@ -387,6 +499,9 @@ class ConfigMixin:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
config_dict["_class_name"] = self.__class__.__name__
config_dict["_diffusers_version"] = __version__
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
@@ -401,38 +516,6 @@ class ConfigMixin:
writer.write(self.to_json_string())
class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
self.__frozen = True
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __setattr__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setattr__(name, value)
def __setitem__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setitem__(name, value)
def register_to_config(init):
r"""
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
@@ -446,7 +529,7 @@ def register_to_config(init):
def inner_init(self, *args, **kwargs):
# Ignore private kwargs in the init.
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
init(self, *args, **init_kwargs)
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
if not isinstance(self, ConfigMixin):
raise RuntimeError(
f"`@register_to_config` was applied to {self.__class__.__name__} init method, but this class does "
@@ -471,7 +554,9 @@ def register_to_config(init):
if k not in ignore and k not in new_kwargs
}
)
new_kwargs = {**config_init_kwargs, **new_kwargs}
getattr(self, "register_to_config")(**new_kwargs)
init(self, *args, **init_kwargs)
return inner_init
@@ -488,7 +573,7 @@ def flax_register_to_config(cls):
)
# Ignore private kwargs in the init. Retrieve all passed attributes
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
init_kwargs = {k: v for k, v in kwargs.items()}
# Retrieve default values
fields = dataclasses.fields(self)
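
The hunks above split loading into `load_config` (path to dictionary) and `from_config` (dictionary to instance), and make `extract_init_dict` return a third "hidden" dictionary so that parameters unknown to the current class survive a switch between compatible classes. A minimal standard-library sketch of that idea follows. `ToyDDPM`, `ToyDDIM`, and `split_config` are hypothetical names invented for illustration, not part of diffusers, and the real implementation also handles deprecated kwargs, `dtype`, and compatible-class lookup.

```python
import inspect
from typing import Any, Dict, Tuple


def split_config(cls: type, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split `config` into arguments accepted by `cls.__init__` and the rest."""
    expected = set(inspect.signature(cls.__init__).parameters) - {"self"}
    # Private entries such as `_class_name` are metadata, not init arguments.
    public = {k: v for k, v in config.items() if not k.startswith("_")}
    init_dict = {k: v for k, v in public.items() if k in expected}
    hidden = {k: v for k, v in public.items() if k not in expected}
    return init_dict, hidden


class ToyDDPM:
    # Hypothetical stand-in for a scheduler class.
    def __init__(self, num_train_timesteps: int = 1000, variance_type: str = "fixed_small"):
        self.config = {"num_train_timesteps": num_train_timesteps, "variance_type": variance_type}


class ToyDDIM:
    # Hypothetical stand-in for a compatible scheduler with different arguments.
    def __init__(self, num_train_timesteps: int = 1000, eta: float = 0.0):
        self.config = {"num_train_timesteps": num_train_timesteps, "eta": eta}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyDDIM":
        init_dict, hidden = split_config(cls, config)
        obj = cls(**init_dict)
        # Keep parameters this class does not use, so switching back later loses nothing.
        obj.config.update(hidden)
        return obj


ddpm = ToyDDPM(num_train_timesteps=500, variance_type="learned")
ddim = ToyDDIM.from_config(ddpm.config)
```

Here `ddim.config` keeps `variance_type` even though `ToyDDIM.__init__` does not accept it, which mirrors the role of `hidden_dict` in the diff.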

View File

@@ -2,7 +2,7 @@
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow<10.0",
"Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0",
"black": "black==22.8",
"datasets": "datasets",
@@ -21,6 +21,8 @@ deps = {
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"safetensors": "safetensors",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"scipy": "scipy",
"regex": "regex!=2019.12.17",
"requests": "requests",

View File

@@ -89,6 +89,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
x = x + scale * grad
x = self.reset_x0(x, conditions, self.action_dim)
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
# TODO: set prediction_type when instantiating the model
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
# apply conditions to the trajectory

View File

@@ -332,7 +332,7 @@ class FlaxModelMixin:
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
" using `from_pt=True`."
" using `from_pt=True`."
)
else:
raise EnvironmentError(

View File

@@ -30,8 +30,10 @@ from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
logging,
)
@@ -51,6 +53,9 @@ if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
if is_safetensors_available():
import safetensors
def get_parameter_device(parameter: torch.nn.Module):
try:
@@ -84,10 +89,13 @@ def get_parameter_dtype(parameter: torch.nn.Module):
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
"""
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
try:
return torch.load(checkpoint_file, map_location="cpu")
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
return torch.load(checkpoint_file, map_location="cpu")
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
except Exception as e:
try:
with open(checkpoint_file) as f:
@@ -104,7 +112,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
f"at '{checkpoint_file}'. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
)
@@ -332,7 +340,7 @@ class ModelMixin(torch.nn.Module):
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
@@ -375,80 +383,44 @@ class ModelMixin(torch.nn.Module):
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
else:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
)
else:
model_file = None
if is_safetensors_available():
try:
# Load from URL or cache if already cached
model_file = hf_hub_download(
model_file = _get_model_file(
pretrained_model_name_or_path,
filename=WEIGHTS_NAME,
weights_name=SAFETENSORS_WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {WEIGHTS_NAME} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {WEIGHTS_NAME}"
)
# restore default dtype
except:
pass
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if low_cpu_mem_usage:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
model, unused_kwargs = cls.from_config(
config, unused_kwargs = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
@@ -462,6 +434,7 @@ class ModelMixin(torch.nn.Module):
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs)
# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None:
@@ -482,7 +455,7 @@ class ModelMixin(torch.nn.Module):
"error_msgs": [],
}
else:
model, unused_kwargs = cls.from_config(
config, unused_kwargs = cls.load_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
@@ -496,6 +469,7 @@ class ModelMixin(torch.nn.Module):
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
@@ -689,3 +663,88 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
return unwrap_model(model.module)
else:
return model
def _get_model_file(
pretrained_model_name_or_path,
*,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a local checkpoint file (PyTorch or safetensors, depending on `weights_name`)
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
return model_file
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
return model_file
else:
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
else:
try:
# Load from URL or cache if already cached
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
return model_file
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name}."
" \nCheck your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
)
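
The `from_pretrained` change above first asks `_get_model_file` for the safetensors weights and falls back to the PyTorch weights if that lookup fails. A rough standard-library sketch of the same preference order for a local directory is shown below. `find_weights_file` is a hypothetical helper, and the two file names are assumed values for `SAFETENSORS_WEIGHTS_NAME` and `WEIGHTS_NAME`, which this diff does not show. The sketch does not cover the Hub download path.

```python
import os
import tempfile
from typing import Optional, Sequence

# Assumed file names; the diff only references the constants, not their values.
SAFETENSORS_NAME = "diffusion_pytorch_model.safetensors"
BIN_NAME = "diffusion_pytorch_model.bin"


def find_weights_file(directory: str, candidates: Sequence[str], subfolder: Optional[str] = None) -> str:
    """Return the first existing candidate, checking the directory and then its subfolder."""
    for name in candidates:
        paths = [os.path.join(directory, name)]
        if subfolder is not None:
            paths.append(os.path.join(directory, subfolder, name))
        for path in paths:
            if os.path.isfile(path):
                return path
    raise EnvironmentError(f"No file named any of {list(candidates)} found in directory {directory}.")
```

Passing `[SAFETENSORS_NAME, BIN_NAME]` as `candidates` prefers safetensors when both files exist. The actual code instead wraps the first lookup in a `try` block and retries with `WEIGHTS_NAME` when `model_file` is still `None`.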

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import Optional
@@ -98,8 +99,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
@@ -125,7 +129,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -151,6 +158,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
)
for d in range(num_layers)
]
@@ -158,7 +166,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 4. Define output layers
if self.is_input_continuous:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
@@ -190,10 +201,16 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
@@ -203,8 +220,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 3. Output
if self.is_input_continuous:
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
if not self.use_linear_projection:
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
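
The `AttentionBlock` hunk that follows replaces a scale of `1 / sqrt(sqrt(d))` applied to both the query and the key with a single scale of `1 / sqrt(d)` applied to their product, where `d` is `channels / num_heads`. The two forms agree mathematically because the fourth-root factor appears twice in the product. The pure-Python sketch below checks this for one query/key pair. The helper names and the sample vectors are made up for illustration, and no torch call is involved.

```python
import math
from typing import Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def score_old(query: Sequence[float], key: Sequence[float], head_dim: int) -> float:
    # Old form: scale query and key separately by the fourth root of head_dim.
    scale = 1 / math.sqrt(math.sqrt(head_dim))
    return dot([q * scale for q in query], [k * scale for k in key])


def score_new(query: Sequence[float], key: Sequence[float], head_dim: int) -> float:
    # New form: scale the product once by the square root of head_dim.
    scale = 1 / math.sqrt(head_dim)
    return dot(query, key) * scale


query = [0.5, -1.0, 2.0, 0.25]
key = [1.5, 0.5, -0.75, 3.0]
old = score_old(query, key, head_dim=4)
new = score_new(query, key, head_dim=4)
```

The values of `old` and `new` match up to floating-point rounding. Rounding behavior in half precision can differ between the two forms, which this sketch does not measure.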
@@ -284,22 +310,52 @@ class AttentionBlock(nn.Module):
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
# transpose
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
scale = 1 / math.sqrt(self.channels / self.num_heads)
# get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
if self.num_heads > 1:
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
# TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
# or reformulate this into a 3D problem?
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
else:
query_states, key_states, value_states = query_proj, key_proj, value_proj
attention_scores = torch.baddbmm(
torch.empty(
query_states.shape[0],
query_states.shape[1],
key_states.shape[1],
dtype=query_states.dtype,
device=query_states.device,
),
query_states,
key_states.transpose(-1, -2),
beta=0,
alpha=scale,
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
# compute attention output
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
if self.num_heads > 1:
# TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
# or reformulate this into a 3D problem?
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
else:
hidden_states = torch.bmm(attention_probs, value_states)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
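
The hunk above contains two formulations of the attention scaling: one multiplies both the query and the key by `1 / sqrt(sqrt(d))`, the other multiplies the product once by `1 / sqrt(d)`, where `d` is the per-head channel count. The second form is what `torch.baddbmm(..., beta=0, alpha=scale)` computes, since with `beta=0` the `input` tensor is ignored. A minimal sketch in plain Python (no PyTorch; the helper names are hypothetical and not part of diffusers) showing that the two scalings agree up to floating-point rounding:

```python
import math


def matmul_transposed(q, k):
    """Return q @ k^T for matrices stored as nested lists of rows."""
    return [[sum(a * b for a, b in zip(q_row, k_row)) for k_row in k] for q_row in q]


def scores_scaled_operands(q, k, head_dim):
    """Scale query and key each by head_dim ** -0.25 before multiplying."""
    s = 1 / math.sqrt(math.sqrt(head_dim))
    q_scaled = [[v * s for v in row] for row in q]
    k_scaled = [[v * s for v in row] for row in k]
    return matmul_transposed(q_scaled, k_scaled)


def scores_scaled_product(q, k, head_dim):
    """Multiply first, then scale the product once by head_dim ** -0.5."""
    s = 1 / math.sqrt(head_dim)
    return [[v * s for v in row] for row in matmul_transposed(q, k)]


# Toy inputs (assumed values, for illustration only).
q = [[1.0, 2.0], [3.0, 4.0]]
k = [[0.5, -1.0], [2.0, 0.0]]
a = scores_scaled_operands(q, k, head_dim=2)
b = scores_scaled_product(q, k, head_dim=2)
```

Scaling once after the product saves one elementwise multiplication per call and maps directly onto the `alpha` argument of a fused batched matmul.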
@@ -337,14 +393,17 @@ class BasicTransformerBlock(nn.Module):
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention(
@@ -366,6 +425,16 @@ class BasicTransformerBlock(nn.Module):
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
# if xformers is installed try to use memory_efficient_attention by default
if is_xformers_available():
try:
self._set_use_memory_efficient_attention_xformers(True)
except Exception as e:
warnings.warn(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)
def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
@@ -401,7 +470,11 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
hidden_states = self.attn1(norm_hidden_states) + hidden_states
if self.only_cross_attention:
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states
# 2. Cross-Attention
norm_hidden_states = (
@@ -507,19 +580,17 @@ class CrossAttention(nn.Module):
return hidden_states
def _attention(self, query, key, value):
# TODO: use baddbmm for better performance
if query.device.type == "mps":
# Better performance on mps (~20-25%)
attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
else:
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
if query.device.type == "mps":
hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
else:
hidden_states = torch.matmul(attention_probs, value)
hidden_states = torch.bmm(attention_probs, value)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
@@ -534,21 +605,15 @@ class CrossAttention(nn.Module):
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
if query.device.type == "mps":
# Better performance on mps (~20-25%)
attn_slice = (
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
* self.scale
)
else:
attn_slice = (
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query[start_idx:end_idx],
key[start_idx:end_idx].transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attn_slice = attn_slice.softmax(dim=-1)
if query.device.type == "mps":
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
else:
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
@@ -666,3 +731,129 @@ class AdaLayerNorm(nn.Module):
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
class DualTransformer2DModel(nn.Module):
"""
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to, but not more than, `num_embeds_ada_norm` steps.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
):
super().__init__()
self.transformers = nn.ModuleList(
[
Transformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=num_layers,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
sample_size=sample_size,
num_vector_embeds=num_vector_embeds,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
)
for _ in range(2)
]
)
# Variables that can be set by a pipeline:
# The ratio of transformer1 to transformer2's output states to be combined during inference
self.mix_ratio = 0.5
# The shape of `encoder_hidden_states` is expected to be
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
self.condition_lengths = [77, 257]
# Which transformer to use to encode which condition.
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
self.transformer_index_for_condition = [1, 0]
def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, condition_lengths[0] + condition_lengths[1], num_features)`):
Concatenated conditional embeddings for the cross attention layers. Each transformer receives its own
slice along the token dimension, as configured by `condition_lengths`.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.attention.Transformer2DModelOutput`] instead of a plain tuple.
Returns:
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
tensor.
"""
input_states = hidden_states
encoded_states = []
tokens_start = 0
for i in range(2):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
transformer_index = self.transformer_index_for_condition[i]
encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
0
]
encoded_states.append(encoded_state - input_states)
tokens_start += self.condition_lengths[i]
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
output_states = output_states + input_states
if not return_dict:
return (output_states,)
return Transformer2DModelOutput(sample=output_states)
def _set_attention_slice(self, slice_size):
for transformer in self.transformers:
transformer._set_attention_slice(slice_size)
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for transformer in self.transformers:
transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
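
The mixing step in `DualTransformer2DModel.forward` can be illustrated without PyTorch. The sketch below is an assumption-laden toy, not the library code: hidden states are flat lists of floats, the two "transformers" are hypothetical callables, and the condition tokens are plain list items. It follows the same three steps as the method above: slice the concatenated condition by `condition_lengths`, route each slice through the transformer chosen by `transformer_index_for_condition`, then blend the two residuals with `mix_ratio` and add the input back.

```python
def split_conditions(tokens, condition_lengths):
    """Split a concatenated token sequence into consecutive per-condition chunks."""
    chunks, start = [], 0
    for length in condition_lengths:
        chunks.append(tokens[start : start + length])
        start += length
    return chunks


def mix_residuals(x, encoders, conditions, transformer_index_for_condition, mix_ratio):
    """Blend the residuals produced by two encoders and add them back to x."""
    residuals = []
    for i, condition in enumerate(conditions):
        encoder = encoders[transformer_index_for_condition[i]]
        encoded = encoder(x, condition)
        # Keep only what the encoder added on top of the shared input.
        residuals.append([e - v for e, v in zip(encoded, x)])
    return [
        v + r0 * mix_ratio + r1 * (1 - mix_ratio)
        for v, r0, r1 in zip(x, residuals[0], residuals[1])
    ]


# Hypothetical stand-ins for the two transformers: each shifts the input
# by an amount that depends on how many condition tokens it received.
def encoder_a(x, cond):
    return [v + len(cond) for v in x]


def encoder_b(x, cond):
    return [v + 2 * len(cond) for v in x]


conditions = split_conditions(list(range(5)), condition_lengths=[2, 3])
mixed = mix_residuals(
    [1.0, 2.0],
    encoders=[encoder_a, encoder_b],
    conditions=conditions,
    transformer_index_for_condition=[1, 0],
    mix_ratio=0.5,
)
```

Subtracting the input before blending means the skip connection is counted once rather than once per transformer.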

View File

@@ -104,6 +104,8 @@ class FlaxBasicTransformerBlock(nn.Module):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
only_cross_attention (`bool`, defaults to `False`):
Whether to only apply cross attention.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -111,10 +113,11 @@ class FlaxBasicTransformerBlock(nn.Module):
n_heads: int
d_head: int
dropout: float = 0.0
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
# self attention
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
@@ -126,7 +129,10 @@ class FlaxBasicTransformerBlock(nn.Module):
def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
if self.only_cross_attention:
hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
else:
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
# cross attention
@@ -159,6 +165,8 @@ class FlaxTransformer2DModel(nn.Module):
Number of transformers block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
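use_linear_projection (`bool`, defaults to `False`):
Whether to use a dense (linear) layer instead of a 1x1 convolution for the input and output projections.
only_cross_attention (`bool`, defaults to `False`):
Whether to only apply cross attention in the transformer blocks.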
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -167,49 +175,70 @@ class FlaxTransformer2DModel(nn.Module):
d_head: int
depth: int = 1
dropout: float = 0.0
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
if self.use_linear_projection:
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
else:
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.transformer_blocks = [
FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
FlaxBasicTransformerBlock(
inner_dim,
self.n_heads,
self.d_head,
dropout=self.dropout,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
for _ in range(self.depth)
]
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
if self.use_linear_projection:
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
else:
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)
if self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height * width, channels)
hidden_states = self.proj_in(hidden_states)
else:
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)
for transformer_block in self.transformer_blocks:
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
hidden_states = hidden_states.reshape(batch, height, width, channels)
if self.use_linear_projection:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channels)
else:
hidden_states = hidden_states.reshape(batch, height, width, channels)
hidden_states = self.proj_out(hidden_states)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
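
The `use_linear_projection` branches above swap the order of the projection and the reshape. That is possible because a 1x1 convolution applies the same linear map at every pixel, so projecting before or after flattening the spatial dimensions gives the same tokens. A minimal sketch in plain Python, ignoring biases and assuming channel-last nested lists (the helper names are hypothetical):

```python
def dense(vec, weight):
    """Apply y = W x, with the weight given as a list of output rows."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def conv1x1_then_flatten(image, weight):
    """Project every pixel of an H x W x C image, then flatten to H*W tokens."""
    projected = [[dense(pixel, weight) for pixel in row] for row in image]
    return [pixel for row in projected for pixel in row]


def flatten_then_linear(image, weight):
    """Flatten an H x W x C image to H*W tokens, then project every token."""
    tokens = [pixel for row in image for pixel in row]
    return [dense(token, weight) for token in tokens]


# Toy 2 x 2 image with 2 channels and a 3 x 2 projection (assumed values).
image = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
tokens_conv = conv1x1_then_flatten(image, weight)
tokens_linear = flatten_then_linear(image, weight)
```

The two variants compute the same function, but the stored parameters differ in shape (a convolution kernel carries two extra size-1 spatial axes), so a checkpoint trained with one is not expected to load into the other without reshaping the weights.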

View File

@@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
implements for all the model (such as downloading or saving, etc.)
Parameters:
sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
Input sample size.
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 3,
out_channels: int = 3,
center_input_sample: bool = False,
@@ -175,7 +175,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(
self,
@@ -209,6 +209,11 @@ class UNet2DModel(ModelMixin, ConfigMixin):
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
t_emb = self.time_proj(timesteps)
# time_proj does not contain any weights and always returns f32 tensors,
# but time_embedding might actually be running in fp16, so we need to cast here.
# There might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
# 2. pre-process
@@ -242,9 +247,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)

View File

@@ -15,7 +15,7 @@ import numpy as np
import torch
from torch import nn
from .attention import AttentionBlock, Transformer2DModel
from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
@@ -32,6 +32,9 @@ def get_down_block(
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -74,6 +77,9 @@ def get_down_block(
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -137,6 +143,9 @@ def get_up_block(
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -166,6 +175,9 @@ def get_up_block(
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -242,7 +254,6 @@ class UNetMidBlock2D(nn.Module):
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
**kwargs,
):
super().__init__()
@@ -322,7 +333,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attention_type="default",
output_scale_factor=1.0,
cross_attention_dim=1280,
**kwargs,
dual_cross_attention=False,
use_linear_projection=False,
):
super().__init__()
@@ -348,16 +360,29 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attentions = []
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
)
)
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
@@ -377,15 +402,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.resnets = nn.ModuleList(resnets)
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
@@ -505,6 +532,9 @@ class CrossAttnDownBlock2D(nn.Module):
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -529,16 +559,30 @@ class CrossAttnDownBlock2D(nn.Module):
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
@@ -556,15 +600,17 @@ class CrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
@@ -1089,6 +1135,9 @@ class CrossAttnUpBlock2D(nn.Module):
attention_type="default",
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -1115,16 +1164,30 @@ class CrossAttnUpBlock2D(nn.Module):
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
@@ -1136,15 +1199,17 @@ class CrossAttnUpBlock2D(nn.Module):
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > self.attn_num_head_channels:
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
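
The reworked `set_attention_slice` checks accept `attn_num_head_channels` as either a single integer or a sequence of integers. A standalone sketch of that validation, written as a free function with a hypothetical name so it can be run outside the model classes:

```python
from typing import Optional, Sequence, Union


def validate_slice_size(slice_size: Optional[int], head_dims: Union[int, Sequence[int]]) -> None:
    """Raise ValueError unless slice_size divides every entry of head_dims."""
    # Normalize a single integer to a one-element list, as the methods above do.
    dims = [head_dims] if isinstance(head_dims, int) else list(head_dims)
    if slice_size is None:
        return  # slicing disabled, nothing to validate
    if any(dim % slice_size != 0 for dim in dims):
        raise ValueError(
            f"Make sure slice_size {slice_size} is a common divisor of "
            f"the number of heads used in cross_attention: {dims}"
        )
    if slice_size > min(dims):
        raise ValueError(
            f"slice_size {slice_size} has to be smaller than or equal to "
            f"the lowest number of heads used in cross_attention: min({dims}) = {min(dims)}"
        )
```

For example, `validate_slice_size(4, [8, 16])` returns silently, while `validate_slice_size(3, 8)` raises a `ValueError` because 3 does not divide 8.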

View File

@@ -46,6 +46,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
num_layers: int = 1
attn_num_head_channels: int = 1
add_downsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -68,6 +70,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -178,6 +182,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
num_layers: int = 1
attn_num_head_channels: int = 1
add_upsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -201,6 +207,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -310,6 +318,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
use_linear_projection: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -331,6 +340,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
dtype=self.dtype,
)
attentions.append(attn_block)

View File

@@ -56,11 +56,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int`, *optional*): The size of the input sample.
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
@@ -97,6 +98,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
@@ -105,7 +107,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: int = 8,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
):
super().__init__()
@@ -121,10 +126,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@@ -143,8 +158,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.down_blocks.append(down_block)
@@ -157,8 +175,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
)
# count how many layers upsample the images
@@ -166,6 +186,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@@ -193,7 +215,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -201,18 +226,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def set_attention_slice(self, slice_size):
if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
head_dims = self.config.attention_head_dim
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > self.config.attention_head_dim:
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for block in self.down_blocks:
@@ -245,6 +272,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -297,6 +325,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
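
Review note on the `set_attention_slice` change in this file: the new checks accept `attention_head_dim` as either one int or one value per block. Below is a minimal standalone sketch of that validation, assuming only what the hunk shows. The function name `validate_slice_size` is hypothetical and is not part of diffusers.

```python
from typing import Optional, Sequence, Union


def validate_slice_size(
    attention_head_dim: Union[int, Sequence[int]],
    slice_size: Optional[int],
) -> None:
    """Hypothetical standalone version of the checks in `set_attention_slice`."""
    # Normalize a single int to a one-element list so both config styles
    # go through the same code path.
    if isinstance(attention_head_dim, int):
        head_dims = [attention_head_dim]
    else:
        head_dims = list(attention_head_dim)

    if slice_size is None:
        # Nothing to validate when slicing is not requested.
        return

    # The slice size has to divide the head count of every block.
    if any(dim % slice_size != 0 for dim in head_dims):
        raise ValueError(
            f"slice_size {slice_size} must be a common divisor of {head_dims}"
        )

    # For positive integers this is already implied by the divisibility
    # check above; it is kept here only to mirror the diff.
    if slice_size > min(head_dims):
        raise ValueError(
            f"slice_size {slice_size} must be <= min({head_dims}) = {min(head_dims)}"
        )
```

With per-block values such as `(5, 10, 20, 20)`, a `slice_size` of 5 passes, while 10 is rejected because it does not divide 5.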

View File

@@ -79,7 +79,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
attention_head_dim (`int`, *optional*, defaults to 8):
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
The dimension of the attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
@@ -97,11 +97,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"DownBlock2D",
)
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
only_cross_attention: Union[bool, Tuple[bool]] = False
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: int = 8
attention_head_dim: Union[int, Tuple[int]] = 8
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
dtype: jnp.dtype = jnp.float32
freq_shift: int = 0
@@ -134,6 +136,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
only_cross_attention = self.only_cross_attention
if isinstance(only_cross_attention, bool):
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
attention_head_dim = self.attention_head_dim
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
# down
down_blocks = []
output_channel = block_out_channels[0]
@@ -148,8 +158,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
attn_num_head_channels=self.attention_head_dim,
attn_num_head_channels=attention_head_dim[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
dtype=self.dtype,
)
else:
@@ -169,13 +181,16 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
attn_num_head_channels=self.attention_head_dim,
attn_num_head_channels=attention_head_dim[-1],
use_linear_projection=self.use_linear_projection,
dtype=self.dtype,
)
# up
up_blocks = []
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
@@ -190,9 +205,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
attn_num_head_channels=self.attention_head_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
add_upsample=not is_final_block,
dropout=self.dropout,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
dtype=self.dtype,
)
else:
@@ -230,9 +247,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
r"""
Args:
sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
timestep (`jnp.ndarray` or `float` or `int`): timesteps
encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
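
Review note on the per-block `attention_head_dim` handling in this file: a single int is broadcast to one value per down block, the mid block takes the last value, and the up blocks walk the values in reverse. The sketch below reproduces that bookkeeping in plain Python. The helper name `per_block_head_dims` is hypothetical and exists only for illustration.

```python
from typing import List, Sequence, Tuple, Union


def per_block_head_dims(
    attention_head_dim: Union[int, Sequence[int]],
    num_blocks: int,
) -> Tuple[Tuple[int, ...], int, List[int]]:
    """Return (down-block dims, mid-block dim, up-block dims).

    Hypothetical helper that mirrors the broadcasting and reversal
    shown in the hunks above.
    """
    # Broadcast a single int to one entry per down block.
    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * num_blocks

    down = tuple(attention_head_dim)
    # The mid block reuses the value of the deepest down block.
    mid = down[-1]
    # Up blocks mirror the down blocks, so they read the list reversed.
    up = list(reversed(down))
    return down, mid, up
```

For example, `per_block_head_dims((5, 10, 20, 20), 4)` gives the down blocks `(5, 10, 20, 20)`, the mid block `20`, and the up blocks `[20, 20, 10, 5]`.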

View File

@@ -565,6 +565,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.encoder(x)
@@ -576,7 +577,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return AutoencoderKLOutput(latent_dist=posterior)
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
@@ -585,6 +586,34 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return DecoderOutput(sample=dec)
def enable_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
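
Review note on `enable_slicing`: the new `decode` splits the batch into single-sample slices, decodes each one, and concatenates the results, so the decoder only ever sees one sample at a time. The sketch below shows the same control flow with plain Python lists instead of tensors. `decode_batch` and `toy_decode` are hypothetical stand-ins, not diffusers APIs.

```python
from typing import Callable, List, Sequence

Latent = List[float]
Image = List[float]

# Records the batch size of every decoder call so the effect of slicing is visible.
calls: List[int] = []


def toy_decode(batch: Sequence[Latent]) -> List[Image]:
    """Stand-in for the VAE decoder: doubles every value in every latent."""
    calls.append(len(batch))
    return [[2.0 * value for value in latent] for latent in batch]


def decode_batch(
    latents: Sequence[Latent],
    decode_fn: Callable[[Sequence[Latent]], List[Image]],
    use_slicing: bool = False,
) -> List[Image]:
    """Decode a batch either in one call or one sample at a time."""
    if use_slicing and len(latents) > 1:
        decoded: List[Image] = []
        for latent in latents:
            # Each call sees a batch of size 1, which bounds peak memory.
            decoded.extend(decode_fn([latent]))
        return decoded
    # Without slicing (or for a single sample) decode everything at once.
    return decode_fn(latents)
```

Both paths return the same images; only the number and size of the decoder calls differ, which is the memory versus call-count trade-off the docstring describes.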

View File

@@ -47,7 +47,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
"FlaxSchedulerMixin": ["save_config", "from_config"],
"FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"],
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
},
"transformers": {
@@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
>>> from diffusers import FlaxDPMSolverMultistepScheduler
>>> model_id = "runwayml/stable-diffusion-v1-5"
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_config(
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
... model_id,
... subfolder="scheduler",
... )
@@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.get_config_dict(
config_dict = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
@@ -317,8 +317,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
# make sure we don't download PyTorch weights
ignore_patterns = "*.bin"
# make sure we don't download PyTorch weights, unless when using from_pt
ignore_patterns = "*.bin" if not from_pt else []
if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__
@@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
else:
cached_folder = pretrained_model_name_or_path
config_dict = cls.get_config_dict(cached_folder)
config_dict = cls.load_config(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
@@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
@@ -411,13 +411,13 @@ class FlaxDiffusionPipeline(ConfigMixin):
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
logger.warning(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warn(
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)

View File

@@ -18,6 +18,7 @@ import importlib
import inspect
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
@@ -25,7 +26,7 @@ import torch
import diffusers
import PIL
from huggingface_hub import snapshot_download
from huggingface_hub import model_info, snapshot_download
from packaging import version
from PIL import Image
from tqdm.auto import tqdm
@@ -43,6 +44,7 @@ from .utils import (
BaseOutput,
deprecate,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
is_transformers_available,
logging,
@@ -57,6 +59,7 @@ if is_transformers_available():
INDEX_FILE = "diffusion_pytorch_model.bin"
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
DUMMY_MODULES_FOLDER = "diffusers.utils"
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
logger = logging.get_logger(__name__)
@@ -65,7 +68,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_config", "from_config"],
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
},
@@ -77,6 +80,9 @@ LOADABLE_CLASSES = {
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
},
"onnxruntime.training": {
"ORTModule": ["save_pretrained", "from_pretrained"],
},
}
ALL_IMPORTABLE_CLASSES = {}
@@ -112,6 +118,23 @@ class AudioPipelineOutput(BaseOutput):
audios: np.ndarray
def is_safetensors_compatible(info) -> bool:
filenames = set(sibling.rfilename for sibling in info.siblings)
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
for pt_filename in pt_filenames:
prefix, raw = os.path.split(pt_filename)
if raw == "pytorch_model.bin":
# transformers specific
sf_filename = os.path.join(prefix, "model.safetensors")
else:
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
if is_safetensors_compatible and sf_filename not in filenames:
logger.warning(f"{sf_filename} not found")
is_safetensors_compatible = False
return is_safetensors_compatible
class DiffusionPipeline(ConfigMixin):
r"""
Base class for all models.
@@ -124,10 +147,13 @@ class DiffusionPipeline(ConfigMixin):
Class attributes:
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
- **config_name** (`str`) -- name of the config file that will store the class and module names of all
components of the diffusion pipeline.
- **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
passed for the pipeline to function (should be overridden by subclasses).
"""
config_name = "model_index.json"
_optional_components = []
def register_modules(self, **kwargs):
# import it here to avoid circular import
@@ -179,12 +205,19 @@ class DiffusionPipeline(ConfigMixin):
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module", None)
expected_modules, optional_kwargs = self._get_signature_keys(self)
def is_saveable_module(name, value):
if name not in expected_modules:
return False
if name in self._optional_components and value[0] is None:
return False
return True
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue
model_cls = sub_model.__class__
save_method_name = None
@@ -207,7 +240,7 @@ class DiffusionPipeline(ConfigMixin):
if torch_device is None:
return self
module_names, _ = self.extract_init_dict(dict(self.config))
module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
@@ -228,7 +261,7 @@ class DiffusionPipeline(ConfigMixin):
Returns:
`torch.device`: The torch device on which the pipeline is located.
"""
module_names, _ = self.extract_init_dict(dict(self.config))
module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
@@ -377,11 +410,11 @@ class DiffusionPipeline(ConfigMixin):
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # Download pipeline, but overwrite scheduler
>>> # Use a different scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
>>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> pipeline.scheduler = scheduler
```
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
@@ -400,7 +433,7 @@ class DiffusionPipeline(ConfigMixin):
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warn(
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
@@ -428,7 +461,7 @@ class DiffusionPipeline(ConfigMixin):
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.get_config_dict(
config_dict = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
@@ -444,7 +477,7 @@ class DiffusionPipeline(ConfigMixin):
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
# make sure we don't download flax weights
ignore_patterns = "*.msgpack"
ignore_patterns = ["*.msgpack"]
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
@@ -458,6 +491,15 @@ class DiffusionPipeline(ConfigMixin):
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)
if is_safetensors_available():
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
if is_safetensors_compatible(info):
ignore_patterns.append("*.bin")
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
@@ -474,13 +516,21 @@ class DiffusionPipeline(ConfigMixin):
else:
cached_folder = pretrained_model_name_or_path
config_dict = cls.get_config_dict(cached_folder)
config_dict = cls.load_config(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if custom_pipeline is not None:
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
else:
file_name = CUSTOM_PIPELINE_FILE_NAME
pipeline_class = get_class_from_dynamic_module(
custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
custom_pipeline, module_file=file_name, cache_dir=custom_pipeline
)
elif cls != DiffusionPipeline:
pipeline_class = cls
@@ -510,38 +560,47 @@ class DiffusionPipeline(ConfigMixin):
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
# define init kwargs
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
# remove `null` components
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
init_kwargs = {}
logger.warning(
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
)
# import it here to avoid circular import
from diffusers import pipelines
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module and passed_class_obj[name] is not None:
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
@@ -557,14 +616,8 @@ class DiffusionPipeline(ConfigMixin):
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warn(
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
@@ -584,7 +637,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
if loaded_sub_model is None and sub_model_should_be_defined:
if loaded_sub_model is None:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
@@ -592,7 +645,10 @@ class DiffusionPipeline(ConfigMixin):
if load_method_name is None:
none_module = class_obj.__module__
if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module:
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
TRANSFORMERS_DUMMY_MODULES_FOLDER
)
if is_dummy_path and "dummy" in none_module:
# call class_obj for nice error message of missing requirements
class_obj()
@@ -635,11 +691,13 @@ class DiffusionPipeline(ConfigMixin):
# 4. Potentially add passed objects if expected
missing_modules = set(expected_modules) - set(init_kwargs.keys())
if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
for module in missing_modules:
init_kwargs[module] = passed_class_obj[module]
init_kwargs[module] = passed_class_obj.get(module, None)
elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
@@ -648,6 +706,14 @@ class DiffusionPipeline(ConfigMixin):
model = pipeline_class(**init_kwargs)
return model
@staticmethod
def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - set(["self"])
return expected_modules, optional_parameters
@property
def components(self) -> Dict[str, Any]:
r"""
@@ -664,16 +730,18 @@ class DiffusionPipeline(ConfigMixin):
... StableDiffusionInpaintPipeline,
... )
>>> img2text = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
>>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
>>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
```
Returns:
A dictionary containing all the modules needed to initialize the pipeline.

"""
components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
expected_modules, optional_parameters = self._get_signature_keys(self)
components = {
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}
if set(components.keys()) != expected_modules:
raise ValueError(
@@ -699,7 +767,7 @@ class DiffusionPipeline(ConfigMixin):
return pil_images
def progress_bar(self, iterable):
def progress_bar(self, iterable=None, total=None):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
@@ -707,7 +775,12 @@ class DiffusionPipeline(ConfigMixin):
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
return tqdm(iterable, **self._progress_bar_config)
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
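
Review note on `_get_signature_keys`, added earlier in this file's diff: it separates `__init__` parameters without a default (treated as expected modules) from parameters with a default (treated as optional keyword arguments). The sketch below reproduces that split on a toy class. `get_signature_keys` and `ToyPipeline` are hypothetical names, and the public `inspect.Parameter.empty` sentinel is swapped in for the private `inspect._empty` used in the diff.

```python
import inspect
from typing import Set, Tuple


def get_signature_keys(cls: type) -> Tuple[Set[str], Set[str]]:
    """Split `cls.__init__` parameters into (required, optional) name sets."""
    parameters = inspect.signature(cls.__init__).parameters
    # Parameters without a default are the modules the pipeline must receive.
    required = {
        name for name, param in parameters.items()
        if param.default is inspect.Parameter.empty
    }
    # Parameters with a default are optional keyword arguments.
    optional = {
        name for name, param in parameters.items()
        if param.default is not inspect.Parameter.empty
    }
    # `self` has no default either, so drop it from the required set.
    return required - {"self"}, optional


class ToyPipeline:
    """Toy class used only to exercise the helper."""

    def __init__(self, unet, scheduler, safety_checker=None, requires_safety_checker=True):
        self.unet = unet
        self.scheduler = scheduler
        self.safety_checker = safety_checker
        self.requires_safety_checker = requires_safety_checker
```

One caveat of this approach: `*args` and `**kwargs` parameters also have no default, so they would land in the required set unless filtered separately.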

View File

@@ -40,7 +40,7 @@ available a colab notebook to directly try them out.
| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* |
| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* |

View File

@@ -15,13 +15,23 @@ else:
from ..utils.dummy_pt_objects import * # noqa F403
if is_torch_available() and is_transformers_available():
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
@@ -29,6 +39,7 @@ if is_transformers_available() and is_onnx_available():
from .stable_diffusion import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionInpaintPipelineLegacy,
OnnxStableDiffusionPipeline,
StableDiffusionOnnxPipeline,
)

View File

@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, is_torch_available, is_transformers_available
@dataclass
# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt
class AltDiffusionPipelineOutput(BaseOutput):
"""
Output class for Alt Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. The PIL images or numpy array represent the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, or `None` if safety checking could not be performed.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
if is_transformers_available() and is_torch_available():
from .modeling_roberta_series import RobertaSeriesModelWithTransformation
from .pipeline_alt_diffusion import AltDiffusionPipeline
from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline

View File

@@ -0,0 +1,110 @@
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel
from transformers.utils import ModelOutput
@dataclass
class TransformationModelOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.
Args:
projection_state (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
The text embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
projection_state: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class RobertaSeriesConfig(XLMRobertaConfig):
def __init__(
self,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
project_dim=512,
pooler_fn="cls",
learn_encoder=False,
use_attention_mask=True,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.project_dim = project_dim
self.pooler_fn = pooler_fn
self.learn_encoder = learn_encoder
self.use_attention_mask = use_attention_mask
class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
base_model_prefix = "roberta"
config_class = RobertaSeriesConfig
def __init__(self, config):
super().__init__(config)
self.roberta = XLMRobertaModel(config)
self.transformation = nn.Linear(config.hidden_size, config.project_dim)
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
):
r""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.base_model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
projection_state = self.transformation(outputs.last_hidden_state)
return TransformationModelOutput(
projection_state=projection_state,
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
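
The `forward` method above applies `self.transformation` to `outputs.last_hidden_state`, so every token is projected and `projection_state` keeps its sequence dimension. Below is a minimal pure-Python sketch of that shape behavior. The helper name `project_tokens` and the toy weights are assumptions for illustration only; the actual model uses `nn.Linear`.

```python
from typing import List

Matrix = List[List[float]]


def project_tokens(hidden: List[Matrix], weight: Matrix, bias: List[float]) -> List[Matrix]:
    """Apply y = W x + b to every token vector of a (batch, seq, hidden) nested list.

    `weight` has shape (project_dim, hidden_size), the same layout nn.Linear stores.
    """
    return [
        [
            [sum(w * x for w, x in zip(row, token)) + b for row, b in zip(weight, bias)]
            for token in sequence
        ]
        for sequence in hidden
    ]


# Toy input (hypothetical values): batch_size=1, sequence_length=2, hidden_size=3
hidden_states = [[[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]]
# Toy projection: hidden_size=3 -> project_dim=2
weight = [[1.0, 1.0, 1.0], [0.0, 2.0, 0.0]]
bias = [0.5, -1.0]

projected = project_tokens(hidden_states, weight, bias)
# The batch and sequence dimensions are unchanged; only the last dimension shrinks.
```

In this sketch the output still has one vector per token, which is why the pipeline can pass the first element of the encoder output directly to the U-Net as `encoder_hidden_states`.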

View File

@@ -0,0 +1,597 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import torch
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Alt Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`RobertaSeriesModelWithTransformation`]):
Frozen text-encoder. Alt Diffusion uses a [`RobertaSeriesModelWithTransformation`], an XLM-RoBERTa model
followed by a linear projection, in place of the CLIP text encoder used by Stable Diffusion.
tokenizer (`XLMRobertaTokenizer`):
Tokenizer of class
[XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/xlm-roberta#transformers.XLMRobertaTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: RobertaSeriesModelWithTransformation,
tokenizer: XLMRobertaTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Alt Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# work around it by only offloading self.safety_checker.vision_model for now
cpu_offload(self.safety_checker.vision_model, device)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than or equal to `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because the text encoder can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than or equal to `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.alt_diffusion.AltDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.alt_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.alt_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
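
Two pieces of arithmetic in the pipeline above are easy to check by hand: the latent shape computed in `prepare_latents` and the classifier-free guidance update in the denoising loop. The following is a minimal sketch in plain Python, using lists in place of tensors. The helper names `latent_shape` and `guided_noise` are hypothetical and do not appear in the source; the formulas are taken from the code above.

```python
from typing import List, Sequence, Tuple


def latent_shape(
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    block_out_channels: Sequence[int],
) -> Tuple[int, int, int, int]:
    """Mirror the shape computed in `prepare_latents`.

    The VAE downsamples once per block transition, hence the
    2 ** (number of blocks - 1) scale factor.
    """
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)
    return (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)


def guided_noise(uncond: List[float], text: List[float], guidance_scale: float) -> List[float]:
    """Combine the two noise predictions as uncond + scale * (text - uncond)."""
    return [u + guidance_scale * (t - u) for u, t in zip(uncond, text)]


# Assumed example config: four VAE blocks give a scale factor of 8.
shape = latent_shape(2, 4, 512, 768, [128, 256, 512, 512])

# Toy noise predictions (hypothetical values) for a two-element "latent".
uncond_pred = [0.0, 1.0]
text_pred = [1.0, 3.0]
guided = guided_noise(uncond_pred, text_pred, 7.5)
no_guidance = guided_noise(uncond_pred, text_pred, 1.0)
```

With `guidance_scale = 1.0` the update returns the text-conditioned prediction unchanged, which is consistent with the pipeline skipping the unconditional pass unless `guidance_scale > 1.0`.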

View File

@@ -0,0 +1,614 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
r"""
Pipeline for text-guided image to image generation using Alt Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`RobertaSeriesModelWithTransformation`]):
Frozen text-encoder. Alt Diffusion uses a [`RobertaSeriesModelWithTransformation`], an XLM-RoBERTa model
followed by a linear projection, in place of the CLIP text encoder used by Stable Diffusion.
tokenizer (`XLMRobertaTokenizer`):
Tokenizer of class
[XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/xlm-roberta#transformers.XLMRobertaTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: RobertaSeriesModelWithTransformation,
tokenizer: XLMRobertaTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Alt Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
@property
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is not greater than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, prompt, strength, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
init_image = init_image.to(device=device, dtype=dtype)
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
deprecation_message = (
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
" images (`init_image`). Initial images are now duplicated to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to update"
" your script to pass as many init images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
init_image: Union[torch.FloatTensor, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
init_image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is not greater than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 1. Check inputs
self.check_inputs(prompt, strength, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# 4. Preprocess image
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
latents = self.prepare_latents(
init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
image = self.decode_latents(latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
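
Review note: the index arithmetic in `get_timesteps` is easier to check in isolation. The sketch below reproduces only the integer logic of that method. The function name `img2img_window` is hypothetical, and it returns the start index and the remaining step count rather than slicing a real scheduler's `timesteps`.

```python
def img2img_window(num_inference_steps: int, strength: float, offset: int = 0):
    """Return (t_start, remaining_steps) using the same arithmetic as get_timesteps."""
    # Number of steps to actually denoise, shifted by the scheduler's steps_offset.
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)
    # Index into scheduler.timesteps where denoising would start.
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    return t_start, num_inference_steps - t_start


if __name__ == "__main__":
    for strength in (0.0, 0.5, 1.0):
        print(strength, img2img_window(50, strength, offset=1))
```

In this sketch, `strength=1.0` combined with `offset=1` yields a start index of 1 rather than 0, so one scheduler step is skipped. Whether that is intended is a question for the pipeline authors.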

View File

@@ -89,7 +89,11 @@ class DDIMPipeline(DiffusionPipeline):
generator = None
# Sample gaussian noise to begin loop
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
if isinstance(self.unet.sample_size, int):
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
else:
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
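
Review note: the new branch lets `sample_size` be either a single integer (square images) or a sequence of two integers. A minimal standalone sketch of that shape logic follows. The helper name `noise_shape` is hypothetical and not part of the pipeline.

```python
from typing import Sequence, Tuple, Union


def noise_shape(
    batch_size: int, in_channels: int, sample_size: Union[int, Sequence[int]]
) -> Tuple[int, ...]:
    """Build the shape of the initial noise tensor for int or (height, width) sample sizes."""
    if isinstance(sample_size, int):
        # Square sample: repeat the single size for both spatial dimensions.
        return (batch_size, in_channels, sample_size, sample_size)
    # Non-square sample: unpack the given spatial dimensions as-is.
    return (batch_size, in_channels, *sample_size)


if __name__ == "__main__":
    print(noise_shape(2, 3, 32))
    print(noise_shape(2, 3, (32, 64)))
```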

View File

@@ -70,14 +70,14 @@ class DDPMPipeline(DiffusionPipeline):
generated images.
"""
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
new_config = dict(self.scheduler.config)
new_config["predict_epsilon"] = predict_epsilon
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
self.scheduler._internal_dict = FrozenDict(new_config)
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
@@ -94,7 +94,11 @@ class DDPMPipeline(DiffusionPipeline):
generator = None
# Sample gaussian noise to begin loop
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
if isinstance(self.unet.sample_size, int):
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
else:
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
@@ -110,9 +114,7 @@ class DDPMPipeline(DiffusionPipeline):
model_output = self.unet(image, t).sample
# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(
model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
).prev_sample
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
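
Review note: the hunks above replace the boolean `predict_epsilon` with a string-valued `prediction_type`. The mapping can be sketched on a plain dictionary as below. The function name `migrate_predict_epsilon` is hypothetical, and a plain `dict` stands in for the scheduler's frozen config.

```python
from typing import Optional


def migrate_predict_epsilon(config: dict, predict_epsilon: Optional[bool]) -> dict:
    """Return a copy of config with the deprecated flag translated to prediction_type."""
    new_config = dict(config)
    if predict_epsilon is not None:
        # True meant "the model predicts noise"; False meant "the model predicts the sample".
        new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
    return new_config


if __name__ == "__main__":
    print(migrate_predict_epsilon({"num_train_timesteps": 1000}, False))
```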

View File

@@ -60,13 +60,14 @@ class LDMTextToImagePipeline(DiffusionPipeline):
):
super().__init__()
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = 256,
width: Optional[int] = 256,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 1.0,
eta: Optional[float] = 0.0,
@@ -79,9 +80,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 256):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 256):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -106,6 +107,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str):
batch_size = 1
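
Review note: the default image size is now derived from the model configuration instead of a hard-coded 256. A small sketch of that computation follows. The helper name `default_size` and the example values for `sample_size` and `block_out_channels` are illustrative assumptions, not values read from a real checkpoint.

```python
from typing import Optional, Sequence, Tuple


def default_size(
    height: Optional[int],
    width: Optional[int],
    sample_size: int,
    block_out_channels: Sequence[int],
) -> Tuple[int, int]:
    """Fall back to sample_size * vae_scale_factor when height or width is not given."""
    # Each VAE block after the first halves the resolution, hence the power of two.
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)
    height = height or sample_size * vae_scale_factor
    width = width or sample_size * vae_scale_factor
    return height, width


if __name__ == "__main__":
    # Hypothetical config: latent sample size 32 and four VAE blocks.
    print(default_size(None, None, 32, (128, 256, 512, 512)))
```

Because the fallback uses `or`, an explicit `0` is treated the same as `None` in this sketch.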

View File

@@ -17,12 +17,13 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
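
Review note: `preprocess` rounds each dimension down to a multiple of 32 before resizing. The rounding step alone can be sketched without PIL as below. The helper name `round_down_to_multiple` is hypothetical.

```python
from typing import Tuple


def round_down_to_multiple(size: Tuple[int, int], multiple: int = 32) -> Tuple[int, int]:
    """Round (width, height) down to the nearest multiple, as preprocess does with 32."""
    w, h = size
    return w - w % multiple, h - h % multiple


if __name__ == "__main__":
    print(round_down_to_multiple((513, 770)))
```

One consequence visible in this sketch is that a dimension smaller than 32 rounds down to 0, so very small inputs would likely need to be upscaled before preprocessing.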

View File

@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, DDIMScheduler
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
@@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png")
# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
@@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
# let's download an initial image

View File

@@ -6,7 +6,14 @@ import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
from ...utils import (
BaseOutput,
is_flax_available,
is_onnx_available,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
@dataclass
@@ -33,12 +40,19 @@ if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .safety_checker import StableDiffusionSafetyChecker
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
else:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
if is_transformers_available() and is_onnx_available():
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline
from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy
if is_transformers_available() and is_flax_available():
import flax
@@ -49,15 +63,14 @@ if is_transformers_available() and is_flax_available():
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
images (`np.ndarray`)
Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
images: np.ndarray
nsfw_content_detected: List[bool]
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState

View File

@@ -20,13 +20,14 @@ import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -37,7 +38,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
@@ -132,6 +133,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -142,6 +144,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -159,8 +162,8 @@ class CycleDiffusionPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -169,6 +172,32 @@ class CycleDiffusionPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -178,6 +207,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
@@ -194,9 +224,14 @@ class CycleDiffusionPipeline(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -209,7 +244,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
self.enable_attention_slicing(None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self):
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -220,12 +255,17 @@ class CycleDiffusionPipeline(DiffusionPipeline):
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -301,7 +341,17 @@ class CycleDiffusionPipeline(DiffusionPipeline):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
@@ -337,7 +387,17 @@ class CycleDiffusionPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -415,7 +475,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps
return timesteps, num_inference_steps - t_start
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
init_image = init_image.to(device=device, dtype=dtype)
@@ -468,7 +528,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -548,7 +607,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.get_timesteps(num_inference_steps, strength, device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
@@ -562,66 +621,70 @@ class CycleDiffusionPipeline(DiffusionPipeline):
generator = extra_step_kwargs.pop("generator", None)
# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
source_latent_model_input = torch.cat([source_latents] * 2)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
source_latent_model_input = torch.cat([source_latents] * 2)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
# predict the noise residual
concat_latent_model_input = torch.stack(
[
source_latent_model_input[0],
latent_model_input[0],
source_latent_model_input[1],
latent_model_input[1],
],
dim=0,
)
concat_text_embeddings = torch.stack(
[
source_text_embeddings[0],
text_embeddings[0],
source_text_embeddings[1],
text_embeddings[1],
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
).sample
# predict the noise residual
concat_latent_model_input = torch.stack(
[
source_latent_model_input[0],
latent_model_input[0],
source_latent_model_input[1],
latent_model_input[1],
],
dim=0,
)
concat_text_embeddings = torch.stack(
[
source_text_embeddings[0],
text_embeddings[0],
source_text_embeddings[1],
text_embeddings[1],
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
).sample
# perform guidance
(
source_noise_pred_uncond,
noise_pred_uncond,
source_noise_pred_text,
noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
source_noise_pred_text - source_noise_pred_uncond
)
# perform guidance
(
source_noise_pred_uncond,
noise_pred_uncond,
source_noise_pred_text,
noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
source_noise_pred_text - source_noise_pred_uncond
)
# Sample source_latents from the posterior distribution.
prev_source_latents = posterior_sample(
self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
)
# Compute noise.
noise = compute_noise(
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
)
source_latents = prev_source_latents
# Sample source_latents from the posterior distribution.
prev_source_latents = posterior_sample(
self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
)
# Compute noise.
noise = compute_noise(
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
)
source_latents = prev_source_latents
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
image = self.decode_latents(latents)
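The reworked denoising loop above advances the progress bar only on some iterations. A minimal sketch of that rule in isolation follows; the function name `count_progress_updates` and its plain-integer arguments are hypothetical and stand in for `len(timesteps)`, `num_inference_steps`, and `self.scheduler.order`.

```python
def count_progress_updates(num_timesteps: int, num_inference_steps: int, order: int) -> int:
    """Count how often the progress bar would be updated under the rule in the diff.

    Hypothetical helper for illustration only; it is not part of diffusers.
    """
    # Iterations beyond `num_inference_steps * order` are treated as warmup.
    num_warmup_steps = num_timesteps - num_inference_steps * order
    updates = 0
    for i in range(num_timesteps):
        # Same condition as in the loop: past warmup, and at the end of a
        # full scheduler step (every `order`-th iteration).
        if (i + 1) > num_warmup_steps and (i + 1) % order == 0:
            updates += 1
    return updates


print(count_progress_updates(num_timesteps=10, num_inference_steps=10, order=1))
print(count_progress_updates(num_timesteps=20, num_inference_steps=10, order=2))
```

With a first-order scheduler every iteration counts as one step. With a second-order scheduler that emits two timesteps per step, only every second iteration updates the bar.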

View File

@@ -23,6 +23,7 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from packaging import version
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
@@ -34,7 +35,7 @@ from ...schedulers import (
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ...utils import logging
from ...utils import deprecate, logging
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -88,7 +89,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self.dtype = dtype
if safety_checker is None:
logger.warn(
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -97,6 +98,27 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -106,6 +128,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
@@ -160,12 +183,17 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: float = 7.5,
latents: Optional[jnp.array] = None,
debug: bool = False,
neg_prompt_ids: jnp.array = None,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -177,13 +205,22 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
if neg_prompt_ids is None:
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
).input_ids
else:
uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
latents_shape = (
batch_size,
self.unet.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
@@ -244,14 +281,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
height: Optional[int] = None,
width: Optional[int] = None,
guidance_scale: float = 7.5,
latents: jnp.array = None,
return_dict: bool = True,
jit: bool = False,
debug: bool = False,
**kwargs,
neg_prompt_ids: jnp.array = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -259,9 +296,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -279,9 +316,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
@@ -296,13 +330,36 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if jit:
images = _p_generate(
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
self,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
)
else:
images = self._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
)
if self.safety_checker is not None:
@@ -322,6 +379,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
images = images.reshape(num_devices, batch_size, height, width, 3)
else:
images = np.asarray(images)
has_nsfw_concept = False
if not return_dict:
@@ -333,10 +391,29 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate(
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
pipe,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
):
return pipe._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
)
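The Flax changes above replace the hard-coded 512 and `// 8` with values derived from the model configuration. A small sketch of that arithmetic, assuming a hypothetical helper `resolve_image_size` and an example `block_out_channels` tuple that is not taken from this diff:

```python
from typing import Optional, Sequence, Tuple


def resolve_image_size(
    sample_size: int,
    block_out_channels: Sequence[int],
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> Tuple[int, int, Tuple[int, int]]:
    """Return (height, width, latent (height, width)) using the defaults from the diff.

    Hypothetical helper for illustration only; it is not part of diffusers.
    """
    # Each VAE block after the first halves the resolution.
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)
    # Fall back to the UNet sample size scaled up to pixel space.
    height = height or sample_size * vae_scale_factor
    width = width or sample_size * vae_scale_factor
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
    latent_hw = (height // vae_scale_factor, width // vae_scale_factor)
    return height, width, latent_hw


# Assumed example configuration: four VAE blocks give a scale factor of 8.
print(resolve_image_size(64, (128, 256, 512, 512)))
```

Passing only one of `height` or `width` leaves the other at its configuration-derived default.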

View File

@@ -41,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
@@ -51,6 +53,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -81,6 +84,22 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
@@ -91,6 +110,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
@@ -185,7 +205,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
if isinstance(prompt, str):
batch_size = 1
@@ -261,8 +280,10 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = np.array(latents)
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
)
latents = scheduler_output.prev_sample.numpy()
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
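The constructor checks added in this file combine `safety_checker`, `feature_extractor`, and the new `requires_safety_checker` flag. The following is a sketch of that decision logic only, under the assumption that it can be lifted out of `__init__`; `check_safety_components` is a hypothetical name and the messages are shortened.

```python
import logging

logger = logging.getLogger(__name__)


def check_safety_components(safety_checker, feature_extractor, requires_safety_checker=True) -> bool:
    """Mirror the two checks from the diff; return True if a warning was logged.

    Hypothetical helper for illustration only; it is not part of diffusers.
    """
    warned = False
    # Warn only when the checker is missing AND the pipeline config says it is required.
    if safety_checker is None and requires_safety_checker:
        logger.warning("You have disabled the safety checker by passing `safety_checker=None`.")
        warned = True
    # A safety checker cannot run without a feature extractor to prepare its inputs.
    if safety_checker is not None and feature_extractor is None:
        raise ValueError("Make sure to define a feature extractor if you want to use the safety checker.")
    return warned


print(check_safety_components(None, None, requires_safety_checker=False))
```

Setting `requires_safety_checker=False` is what lets a pipeline be loaded without a checker and without the warning.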

View File

@@ -25,7 +25,7 @@ from ...configuration_utils import FrozenDict
from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
@@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
return 2.0 * image - 1.0
@@ -77,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
@@ -87,6 +89,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -117,7 +120,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
@@ -127,6 +130,12 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
@@ -137,6 +146,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@@ -231,7 +241,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -401,8 +410,10 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = latents.numpy()
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
)
latents = scheduler_output.prev_sample.numpy()
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
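The `preprocess` helper in this file snaps the input image size down to a multiple of 32 before resizing. The rounding step on its own can be sketched as follows; `round_down_size` is a hypothetical name, and the resize itself is left out so the example needs no image library.

```python
from typing import Tuple


def round_down_size(size: Tuple[int, int], multiple: int = 32) -> Tuple[int, int]:
    """Round (width, height) down to the nearest integer multiple of `multiple`.

    Hypothetical helper for illustration only; it is not part of diffusers.
    """
    w, h = size
    # Same expression as in `preprocess`: subtract the remainder.
    w, h = (x - x % multiple for x in (w, h))
    return w, h


print(round_down_size((513, 767)))
```

Sizes that are already multiples of 32 pass through unchanged.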

View File

@@ -25,7 +25,7 @@ from ...configuration_utils import FrozenDict
from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
@@ -44,7 +44,7 @@ def prepare_mask_and_masked_image(image, mask, latents_shape):
image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8)))
masked_image = image * (image_mask < 127.5)
mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST)
mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"])
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
@@ -90,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
@@ -100,6 +102,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
@@ -131,7 +134,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
@@ -141,6 +144,12 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
@@ -151,6 +160,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@@ -236,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
prompt: Union[str, List[str]],
image: PIL.Image.Image,
mask_image: PIL.Image.Image,
height: int = 512,
width: int = 512,
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -249,7 +259,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -312,6 +321,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
@@ -408,9 +418,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = latent_model_input.cpu().numpy()
latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
@@ -424,8 +434,10 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
latents = latents.numpy()
scheduler_output = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
)
latents = scheduler_output.prev_sample.numpy()
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
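
The hunk above moves the channel-wise concatenation to after `scale_model_input`, so only the noisy latents are rescaled, while the mask and masked-image latents reach the UNet unchanged. A minimal NumPy sketch of that ordering follows. `scale_latents` is a hypothetical stand-in for the scheduler call (it uses the `1 / sqrt(sigma**2 + 1)` scaling of sigma-based schedulers), and the shapes and sigma value are illustrative assumptions, not part of the commit.

```python
import numpy as np


def scale_latents(latents: np.ndarray, sigma: float) -> np.ndarray:
    # Hypothetical stand-in for `scheduler.scale_model_input`.
    return latents / np.sqrt(sigma**2 + 1.0)


def build_unet_input(latents, mask, masked_image_latents, sigma, do_classifier_free_guidance=True):
    # Duplicate the latents when classifier-free guidance is enabled.
    model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
    # Scale only the noisy latents ...
    model_input = scale_latents(model_input, sigma)
    # ... then append mask and masked-image latents along the channel axis.
    return np.concatenate([model_input, mask, masked_image_latents], axis=1)


# Assumed shapes: 4 latent channels, 1 mask channel, 4 masked-image channels.
# `mask` and `masked_image_latents` are assumed to be duplicated for guidance already.
latents = np.ones((1, 4, 8, 8), dtype=np.float32)
mask = np.ones((2, 1, 8, 8), dtype=np.float32)
masked_image_latents = np.ones((2, 4, 8, 8), dtype=np.float32)

unet_input = build_unet_input(latents, mask, masked_image_latents, sigma=3.0)
print(unet_input.shape)
```

With the previous ordering the mask channel would have been divided by the same factor as the latents, which is what the reordering avoids.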

@@ -0,0 +1,456 @@
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import PIL
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...onnx_utils import OnnxRuntimeModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
return 2.0 * image - 1.0
def preprocess_mask(mask, scale_factor=8):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # add a batch dimension; the (0, 1, 2, 3) transpose is an identity no-op
mask = 1 - mask # repaint white, keep black
return mask
class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to
provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
tokenizer: CLIPTokenizer
unet: OnnxRuntimeModel
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
vae_decoder: OnnxRuntimeModel,
text_encoder: OnnxRuntimeModel,
tokenizer: CLIPTokenizer,
unet: OnnxRuntimeModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`):
prompt to be encoded
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is not greater than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt] * batch_size
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="np",
)
uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
return text_embeddings
def __call__(
self,
prompt: Union[str, List[str]],
init_image: Union[np.ndarray, PIL.Image.Image],
mask_image: Union[np.ndarray, PIL.Image.Image],
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[np.random.RandomState] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
init_image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted.
mask_image (`np.ndarray` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
`init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is not greater than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if generator is None:
generator = np.random
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
text_embeddings = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
latents_dtype = text_embeddings.dtype
init_image = init_image.astype(latents_dtype)
# encode the init image into latents and scale the latents
init_latents = self.vae_encoder(sample=init_image)[0]
init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size and num_images_per_prompt
init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0)
init_latents_orig = init_latents
# preprocess mask
if not isinstance(mask_image, np.ndarray):
mask_image = preprocess_mask(mask_image, 8)
mask_image = mask_image.astype(latents_dtype)
mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)
# check sizes
if not mask.shape == init_latents.shape:
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps.numpy()[-init_timestep]
timesteps = np.array([timesteps] * batch_size * num_images_per_prompt)
# add noise to latents using the timesteps
noise = generator.randn(*init_latents.shape).astype(latents_dtype)
init_latents = self.scheduler.add_noise(
torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps)
)
init_latents = init_latents.numpy()
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:].numpy()
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
).prev_sample
latents = latents.numpy()
init_latents_proper = self.scheduler.add_noise(
torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t]))
)
init_latents_proper = init_latents_proper.numpy()
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
# decode one latent at a time: the half-precision VAE decoder seems to give strange results when the batch size is > 1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
).pixel_values.astype(image.dtype)
# the safety checker raises an error for batch sizes > 1, so run it one image at a time
images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
)
images.append(image_i)
has_nsfw_concept.append(has_nsfw_concept_i[0])
image = np.concatenate(images)
else:
has_nsfw_concept = None
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
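
The way `strength` shortens the denoising schedule in the `__call__` above can be sketched in isolation. The helper below repeats the same arithmetic in plain Python. The name `denoising_window` is hypothetical, and the last returned value assumes `scheduler.timesteps` holds exactly `num_inference_steps` entries, which may not be true for every scheduler.

```python
def denoising_window(num_inference_steps: int, strength: float, offset: int = 0):
    """Return (init_timestep, t_start, steps_actually_run) for a given strength."""
    # Same arithmetic as in the pipeline: more strength means more noise is added
    # to the init image and more of the schedule is run.
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)
    # The loop iterates over `timesteps[t_start:]`.
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    return init_timestep, t_start, num_inference_steps - t_start


print(denoising_window(50, 0.5))            # (25, 25, 25)
print(denoising_window(50, 1.0, offset=1))  # (50, 1, 49)
```

Under these assumptions a `strength` of 0.5 runs half of the schedule. A `strength` of 1.0 runs nearly all of it, and the result is then almost independent of `init_image` inside the masked region.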

@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Union
import torch
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -66,6 +67,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -83,6 +85,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -113,8 +116,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -123,6 +126,33 @@ class StableDiffusionPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -132,6 +162,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_xformers_memory_efficient_attention(self):
r"""
@@ -165,9 +197,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
@@ -178,7 +215,23 @@ class StableDiffusionPipeline(DiffusionPipeline):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self):
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -189,12 +242,17 @@ class StableDiffusionPipeline(DiffusionPipeline):
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# work around it by offloading only `self.safety_checker.vision_model` for now
cpu_offload(self.safety_checker.vision_model, device)
@property
def _execution_device(self):
r"""
@@ -248,7 +306,17 @@ class StableDiffusionPipeline(DiffusionPipeline):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
@@ -284,7 +352,17 @@ class StableDiffusionPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -349,7 +427,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -369,8 +447,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
height: int = 512,
width: int = 512,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -382,7 +460,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -390,9 +467,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -438,6 +515,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
@@ -476,25 +556,29 @@ class StableDiffusionPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
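
The progress-bar bookkeeping introduced in this hunk can be sketched in isolation. This is a minimal illustration in plain Python, not the pipeline's code: `count_progress_updates` is a hypothetical helper, and the timestep counts passed to it are made-up values, not what any particular scheduler produces.

```python
def count_progress_updates(num_timesteps: int, num_inference_steps: int, order: int = 1) -> int:
    """Count how often the loop above would call progress_bar.update()."""
    # Timesteps beyond num_inference_steps * order are treated as warmup.
    num_warmup_steps = num_timesteps - num_inference_steps * order
    updates = 0
    for i in range(num_timesteps):
        # Same condition as in the hunk: skip warmup, then tick once per `order` model calls.
        if (i + 1) > num_warmup_steps and (i + 1) % order == 0:
            updates += 1
    return updates


if __name__ == "__main__":
    print(count_progress_updates(50, 50, order=1))   # no warmup timesteps
    print(count_progress_updates(53, 50, order=1))   # three hypothetical warmup timesteps
    print(count_progress_updates(100, 50, order=2))  # hypothetical second-order schedule
```

Under these assumed inputs the bar advances once per inference step even though the loop itself runs over more timesteps, which appears to be the reason for replacing `self.progress_bar(timesteps)` with an explicit `total=num_inference_steps`.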

View File

@@ -0,0 +1,481 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionImageVariationPipeline(DiffusionPipeline):
r"""
Pipeline to generate variations from an input image using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
image_encoder: CLIPVisionModelWithProjection,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
image_encoder=image_encoder,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
image_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image_embeddings = image_embeddings.unsqueeze(1)
# duplicate image embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = image_embeddings.shape
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
uncond_embeddings = torch.zeros_like(image_embeddings)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and image embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
return image_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, image, height, width, callback_steps):
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
configuration of
[this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
`CLIPFeatureExtractor`
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages generated images to stay closely linked to the input `image`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per input image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different input images. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width, callback_steps)
# 2. Define call parameters
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input image
image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
image_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
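
The guidance step in the denoising loop above reduces to simple arithmetic on the two halves of the batched prediction. The sketch below reproduces it on plain Python lists instead of tensors; `split_batch` and `apply_guidance` are hypothetical names used only for illustration. In this pipeline the unconditional half comes from the zeroed image embeddings built in `_encode_image`.

```python
from typing import List, Sequence, Tuple


def split_batch(noise_pred: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Stand-in for `noise_pred.chunk(2)`: first half unconditional, second half conditional."""
    half = len(noise_pred) // 2
    return list(noise_pred[:half]), list(noise_pred[half:])


def apply_guidance(noise_pred: Sequence[float], guidance_scale: float) -> List[float]:
    """Combine both halves as uncond + guidance_scale * (cond - uncond)."""
    uncond, cond = split_batch(noise_pred)
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]


if __name__ == "__main__":
    batched = [0.0, 1.0, 2.0, 3.0]  # [uncond_0, uncond_1, cond_0, cond_1]
    print(apply_guidance(batched, 1.0))  # scale 1 returns the conditional prediction
    print(apply_guidance(batched, 7.5))  # larger scales extrapolate away from the unconditional one
```

With `guidance_scale = 1` the result equals the conditional prediction, which is why the pipeline skips the doubled batch unless `guidance_scale > 1.0`.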

View File

@@ -20,6 +20,7 @@ import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -33,7 +34,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -44,7 +45,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
@@ -78,6 +79,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
@@ -96,6 +98,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -126,8 +129,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -136,6 +139,33 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -145,6 +175,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
@@ -161,9 +193,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -176,7 +213,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
self.enable_attention_slicing(None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self):
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -187,12 +224,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# work around it by only offloading self.safety_checker.vision_model for now
cpu_offload(self.safety_checker.vision_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -268,7 +310,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
@@ -304,7 +356,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -380,7 +442,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps
return timesteps, num_inference_steps - t_start
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
init_image = init_image.to(device=device, dtype=dtype)
@@ -431,7 +493,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -510,7 +571,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.get_timesteps(num_inference_steps, strength, device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
@@ -522,25 +583,29 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
image = self.decode_latents(latents)
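
Note on the progress-bar change in the denoising loop above: the bar is now created with `total=num_inference_steps` and advanced manually, so scheduler timesteps that exceed `num_inference_steps * self.scheduler.order` are treated as warm-up and do not advance it. A minimal standalone sketch of that bookkeeping follows. The helper name `count_progress_updates` and the timestep counts are hypothetical and chosen only for illustration; no real scheduler is involved.

```python
def count_progress_updates(num_timesteps, num_inference_steps, order):
    """Count how often the loop condition from the diff would advance the bar."""
    # Timesteps beyond num_inference_steps * order count as warm-up.
    num_warmup_steps = num_timesteps - num_inference_steps * order
    updates = 0
    for i in range(num_timesteps):
        # Same condition as in the denoising loop.
        if (i + 1) > num_warmup_steps and (i + 1) % order == 0:
            updates += 1
    return updates


# Hypothetical first-order scheduler with exactly one timestep per step.
print(count_progress_updates(num_timesteps=50, num_inference_steps=50, order=1))  # 50
# Hypothetical first-order scheduler that emits one extra warm-up timestep.
print(count_progress_updates(num_timesteps=51, num_inference_steps=50, order=1))  # 50
```

In both cases the bar reaches its declared total of 50, which appears to be the point of subtracting the warm-up steps.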

View File

@@ -20,6 +20,7 @@ import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -35,16 +36,88 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def prepare_mask_and_masked_image(image, mask):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
``image`` and ``1`` for the ``mask``.
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
binarized (``mask >= 0.5``) and cast to ``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(or the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")
# Batch single image
if image.ndim == 3:
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError("Image should be in [-1, 1] range")
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
else:
if isinstance(image, PIL.Image.Image):
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
if isinstance(mask, PIL.Image.Image):
mask = np.array(mask.convert("L"))
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
@@ -78,6 +151,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -88,6 +162,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -119,8 +194,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["skip_prk_steps"] = True
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -129,6 +204,33 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -138,6 +240,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
@@ -154,9 +258,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -169,7 +278,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
self.enable_attention_slicing(None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self):
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -180,12 +289,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# work around it by only offloading self.safety_checker.vision_model for now
cpu_offload(self.safety_checker.vision_model, device)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
@@ -261,7 +375,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
@@ -297,7 +421,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -367,7 +501,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // 8, width // 8)
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -389,7 +523,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
@@ -417,8 +553,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
height: int = 512,
width: int = 512,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -430,7 +566,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -446,9 +581,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, 1, H, W)`.
height (`int`, *optional*, defaults to 512):
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -494,6 +629,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs
self.check_inputs(prompt, height, width, callback_steps)
@@ -517,7 +655,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps_tensor = self.scheduler.timesteps
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
@@ -561,29 +699,32 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 10. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 11. Post-processing
image = self.decode_latents(latents)
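
Note on the `vae_scale_factor` change above: the hard-coded `8` is replaced by a factor derived from the VAE config, and `height`/`width` now default to `sample_size * vae_scale_factor`. The sketch below replays that arithmetic in plain Python. The function names and the config values (`[128, 256, 512, 512]`, `sample_size=64`) are illustrative assumptions, not values read from any checkpoint.

```python
def vae_scale_factor(block_out_channels):
    # Same formula as in the diff: 2 ** (number of VAE blocks - 1).
    return 2 ** (len(block_out_channels) - 1)


def default_size(sample_size, scale_factor, requested=None):
    # Mirrors `height = height or self.unet.config.sample_size * self.vae_scale_factor`.
    return requested or sample_size * scale_factor


scale = vae_scale_factor([128, 256, 512, 512])  # illustrative config values
print(scale)                                    # 8
print(default_size(64, scale))                  # 512
print(default_size(64, scale, requested=768))   # 768
```

With these assumed values the defaults reproduce the previous fixed 512 x 512 size, while a UNet with a larger `sample_size` would get a larger default without further code changes.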

View File

@@ -20,6 +20,7 @@ import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -33,7 +34,7 @@ from ...schedulers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -44,18 +45,18 @@ logger = logging.get_logger(__name__)
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0
def preprocess_mask(mask):
def preprocess_mask(mask, scale_factor=8):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -91,6 +92,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
@@ -109,6 +111,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -139,8 +142,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -149,6 +152,33 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -158,6 +188,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
@@ -174,9 +206,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -189,7 +226,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
self.enable_attention_slicing(None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self):
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -200,12 +237,17 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# work around it by only offloading self.safety_checker.vision_model for now
cpu_offload(self.safety_checker.vision_model, device)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
@@ -281,7 +323,17 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
@@ -317,7 +369,17 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -395,7 +457,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps
return timesteps, num_inference_steps - t_start
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
init_image = init_image.to(device=self.device, dtype=dtype)
@@ -430,7 +492,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -512,11 +573,11 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
init_image = preprocess_image(init_image)
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.get_timesteps(num_inference_steps, strength, device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
@@ -533,29 +594,33 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 9. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Post-processing
image = self.decode_latents(latents)
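
Note on the `get_timesteps` change above: the method now returns both the truncated timestep list and the number of steps that remain, which the loop uses as the progress-bar total. A rough standalone sketch follows. Only the `t_start` line, the slice, and the return value come from the diff; the derivation of `init_timestep` from `strength` and the `offset` parameter are assumptions about the unchanged part of the method, and the timestep values are made up.

```python
def get_timesteps(scheduler_timesteps, num_inference_steps, strength, offset=0):
    # Assumed (not shown in the hunk): how init_timestep is derived from strength.
    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
    # From the diff: skip the first t_start timesteps.
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    timesteps = scheduler_timesteps[t_start:]
    # New in this change: also report how many steps are left to run.
    return timesteps, num_inference_steps - t_start


all_timesteps = list(range(980, -1, -20))  # 50 illustrative timesteps
timesteps, remaining = get_timesteps(all_timesteps, num_inference_steps=50, strength=0.8)
print(len(timesteps), remaining)  # 40 40
```

Under these assumptions a `strength` of 0.8 skips the first 10 of 50 timesteps, so the progress bar is sized for the 40 steps that actually run.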

View File

@@ -0,0 +1,555 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import PIL
from diffusers.utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
# resize to multiple of 64
width, height = image.size
width = width - width % 64
height = height - height % 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
return image
class StableDiffusionUpscalePipeline(DiffusionPipeline):
r"""
Pipeline for text-guided image super-resolution using Stable Diffusion 2.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
low_res_scheduler ([`SchedulerMixin`]):
A scheduler used to add initial noise to the low res conditioning image. It must be an instance of
[`DDPMScheduler`].
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
low_res_scheduler: DDPMScheduler,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
max_noise_level: int = 350,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
low_res_scheduler=low_res_scheduler,
scheduler=scheduler,
)
self.register_to_config(max_noise_level=max_noise_level)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads models to CPU using accelerate, significantly reducing memory usage. When called, the unet and
text_encoder have their state dicts saved to CPU and are then moved to
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is not greater than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
def decode_latents(self, latents):
latents = 1 / 0.08333 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def check_inputs(self, prompt, image, noise_level, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
)
# verify batch size of prompt and image are same if image is a list or tensor
if isinstance(image, list) or isinstance(image, torch.Tensor):
if isinstance(prompt, str):
batch_size = 1
else:
batch_size = len(prompt)
if isinstance(image, list):
image_batch_size = len(image)
else:
image_batch_size = image.shape[0]
if batch_size != image_batch_size:
raise ValueError(
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
" Please make sure that passed `prompt` matches the batch size of `image`."
)
# check noise level
if noise_level > self.config.max_noise_level:
raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
`Image`, or tensor representing an image batch which will be upscaled.
num_inference_steps (`int`, *optional*, defaults to 75):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 9.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
`prompt`, usually at the expense of lower image quality.
noise_level (`int`, *optional*, defaults to 20):
The timestep at which `low_res_scheduler` adds noise to the low-resolution conditioning image. Must not
exceed `max_noise_level`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is not greater than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return an [`~pipeline_utils.ImagePipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`:
[`~pipeline_utils.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images.
"""
# 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# 4. Preprocess image
image = [image] if isinstance(image, PIL.Image.Image) else image
if isinstance(image, list):
image = [preprocess(img) for img in image]
image = torch.cat(image, dim=0)
image = image.to(dtype=text_embeddings.dtype, device=device)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(image.shape, generator=generator, device="cpu", dtype=text_embeddings.dtype).to(device)
else:
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
image = torch.cat([image] * 2) if do_classifier_free_guidance else image
noise_level = torch.cat([noise_level] * 2) if do_classifier_free_guidance else noise_level
# 6. Prepare latent variables
height, width = image.shape[2:]
num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 7. Check that sizes of image and latents match
num_channels_image = image.shape[1]
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# scale the latents, then concat them with the noised low-resolution image in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latent_model_input, image], dim=1)
# predict the noise residual
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Post-processing
# make sure the VAE is in float32 mode, as it overflows in float16
self.vae.to(dtype=torch.float32)
image = self.decode_latents(latents.float())
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
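
`prepare_extra_step_kwargs` in the file above inspects the scheduler's `step` signature so that `eta` and `generator` are only forwarded to schedulers that accept them. A minimal sketch of that technique with the standard-library `inspect` module follows; the two scheduler classes are hypothetical stand-ins, not diffusers schedulers, and the free function takes the scheduler as an argument only to keep the example self-contained.

```python
import inspect


class StepWithEta:
    """Hypothetical scheduler whose `step` accepts `eta` and `generator`."""

    def step(self, model_output, timestep, sample, eta=0.0, generator=None):
        return sample


class StepWithoutEta:
    """Hypothetical scheduler whose `step` accepts neither."""

    def step(self, model_output, timestep, sample):
        return sample


def prepare_extra_step_kwargs(scheduler, generator, eta):
    # Parameter names of the bound method (`self` is already excluded).
    accepted = set(inspect.signature(scheduler.step).parameters)
    candidates = {"eta": eta, "generator": generator}
    # Forward only the keyword arguments this scheduler can receive.
    return {name: value for name, value in candidates.items() if name in accepted}


with_eta = prepare_extra_step_kwargs(StepWithEta(), generator=None, eta=0.5)
without_eta = prepare_extra_step_kwargs(StepWithoutEta(), generator=None, eta=0.5)
print(with_eta, without_eta)
```

Filtering by signature lets the denoising loop call `scheduler.step(..., **extra_step_kwargs)` uniformly, without a branch per scheduler class.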

View File

@@ -0,0 +1,72 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, is_torch_available, is_transformers_available
@dataclass
class SafetyConfig(object):
WEAK = {
"sld_warmup_steps": 15,
"sld_guidance_scale": 20,
"sld_threshold": 0.0,
"sld_momentum_scale": 0.0,
"sld_mom_beta": 0.0,
}
MEDIUM = {
"sld_warmup_steps": 10,
"sld_guidance_scale": 1000,
"sld_threshold": 0.01,
"sld_momentum_scale": 0.3,
"sld_mom_beta": 0.4,
}
STRONG = {
"sld_warmup_steps": 7,
"sld_guidance_scale": 2000,
"sld_threshold": 0.025,
"sld_momentum_scale": 0.5,
"sld_mom_beta": 0.7,
}
MAX = {
"sld_warmup_steps": 0,
"sld_guidance_scale": 5000,
"sld_threshold": 1.0,
"sld_momentum_scale": 0.5,
"sld_mom_beta": 0.7,
}
@dataclass
class StableDiffusionSafePipelineOutput(BaseOutput):
"""
Output class for Safe Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, or `None` if safety checking could not be performed.
unsafe_images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images that were flagged by the safety checker and may contain "not-safe-for-work"
(nsfw) content, or `None` if no safety check was performed or no images were flagged.
applied_safety_concept (`str`)
The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
applied_safety_concept: Optional[str]
if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
from .safety_checker import SafeStableDiffusionSafetyChecker
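
The `SafetyConfig` presets above are plain dictionaries keyed by `sld_*` names, which suggests they are meant to be unpacked as keyword arguments. A minimal sketch of combining a preset with per-call overrides follows, under stated assumptions: `resolve_sld_kwargs` is a hypothetical helper, and the consumer of the resulting dictionary is assumed rather than shown in this diff. The `MEDIUM` values are copied from the file.

```python
# Values copied from SafetyConfig.MEDIUM in the file above.
MEDIUM = {
    "sld_warmup_steps": 10,
    "sld_guidance_scale": 1000,
    "sld_threshold": 0.01,
    "sld_momentum_scale": 0.3,
    "sld_mom_beta": 0.4,
}


def resolve_sld_kwargs(preset, **overrides):
    """Return a copy of `preset` with `overrides` applied (hypothetical helper).

    Unknown keys are rejected so that a misspelled option fails loudly.
    """
    unknown = set(overrides) - set(preset)
    if unknown:
        raise KeyError(f"Unknown safety option(s): {sorted(unknown)}")
    return {**preset, **overrides}


sld_kwargs = resolve_sld_kwargs(MEDIUM, sld_warmup_steps=5)
print(sld_kwargs)
```

Merging into a new dictionary leaves the shared preset untouched, so one call's overrides cannot leak into later calls.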

View File

@@ -0,0 +1,764 @@
import inspect
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, is_accelerate_available, logging
from . import StableDiffusionSafePipelineOutput
from .safety_checker import SafeStableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionPipelineSafe(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Safe Latent Diffusion.
The implementation is based on the [`StableDiffusionPipeline`].
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`SafeStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
],
safety_checker: SafeStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
safety_concept: Optional[str] = (
"an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity,"
" bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child"
" abuse, brutality, cruelty"
)
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self._safety_text_concept = safety_concept
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@property
def safety_concept(self):
r"""
Getter method for the safety concept used with SLD
Returns:
`str`: The text describing the safety concept
"""
return self._safety_text_concept
@safety_concept.setter
def safety_concept(self, concept):
r"""
Setter method for the safety concept used with SLD
Args:
concept (`str`):
The text of the new safety concept
"""
self._safety_text_concept = concept
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
enable_safety_guidance,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
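if `guidance_scale` is less than `1`).
enable_safety_guidance (`bool`):
whether to also encode the safety concept text for Safe Latent Diffusion (SLD) guidance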
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# Encode the safety concept text
if enable_safety_guidance:
safety_concept_input = self.tokenizer(
[self._safety_text_concept],
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(device))[0]
# duplicate safety embeddings for each generation per prompt, using mps friendly method
seq_len = safety_embeddings.shape[1]
safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1)
safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance + sld, we need to do three forward passes.
# Here we concatenate the unconditional, text and safety embeddings into a single batch
# to avoid doing three forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings])
else:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
flagged_images = None
if any(has_nsfw_concept):
logger.warning(
"Potential NSFW content was detected in one or more images. A black image will be returned"
" instead."
f" {'You may look at these images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} "
)
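# allocate one slot per image in the batch (a hard-coded batch of 2 fails for larger batches)
flagged_images = np.zeros(image.shape)
# use a separate loop variable so the `has_nsfw_concept` list is not overwritten before it is returned
for idx, is_nsfw in enumerate(has_nsfw_concept):
if is_nsfw:
flagged_images[idx] = image[idx]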
image[idx] = np.zeros(image[idx].shape) # black image
else:
has_nsfw_concept = None
flagged_images = None
return image, has_nsfw_concept, flagged_images
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def perform_safety_guidance(
self,
enable_safety_guidance,
safety_momentum,
noise_guidance,
noise_pred_out,
i,
sld_guidance_scale,
sld_warmup_steps,
sld_threshold,
sld_momentum_scale,
sld_mom_beta,
):
# Perform SLD guidance
if enable_safety_guidance:
if safety_momentum is None:
safety_momentum = torch.zeros_like(noise_guidance)
# the batch is ordered [unconditional, text, safety concept], matching `_encode_prompt`
noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
noise_pred_safety_concept = noise_pred_out[2]
# Equation 6
scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0)
# Equation 6
safety_concept_scale = torch.where(
(noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
)
# Equation 4
noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale)
# Equation 7
noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
# Equation 8
safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
if i >= sld_warmup_steps: # Warmup
# Equation 3
noise_guidance = noise_guidance - noise_guidance_safety
return noise_guidance, safety_momentum
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
sld_guidance_scale: Optional[float] = 1000,
sld_warmup_steps: Optional[int] = 10,
sld_threshold: Optional[float] = 0.01,
sld_momentum_scale: Optional[float] = 0.3,
sld_mom_beta: Optional[float] = 0.4,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
sld_guidance_scale (`float`, *optional*, defaults to 1000):
Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
`sld_guidance_scale` is defined as `s_S` of Eq. 6. If set to be less than 1, safety guidance will be
disabled.
sld_warmup_steps (`int`, *optional*, defaults to 10):
Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than
`sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent
Diffusion](https://arxiv.org/abs/2211.05105).
sld_threshold (`float`, *optional*, defaults to 0.01):
Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold`
is defined as `lambda` of Eq. 5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
sld_momentum_scale (`float`, *optional*, defaults to 0.3):
Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0
momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller
than `sld_warmup_steps`. `sld_momentum_scale` is defined as `s_m` of Eq. 7 in [Safe Latent
Diffusion](https://arxiv.org/abs/2211.05105).
sld_mom_beta (`float`, *optional*, defaults to 0.4):
Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous
momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller
than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta_m` of Eq. 8 in [Safe Latent
Diffusion](https://arxiv.org/abs/2211.05105).
Returns:
[`~pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`, the third element is the applied safety concept (or `None`
if safety guidance is disabled), and the fourth element contains the flagged images (or `None`).
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance
if not enable_safety_guidance:
warnings.warn("Safety guidance disabled!")
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
safety_momentum = None
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * (3 if enable_safety_guidance else 2))
if do_classifier_free_guidance
else latents
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
# default classifier free guidance
noise_guidance = noise_pred_text - noise_pred_uncond
# Perform SLD guidance
if enable_safety_guidance:
if safety_momentum is None:
safety_momentum = torch.zeros_like(noise_guidance)
noise_pred_safety_concept = noise_pred_out[2]
# Equation 6
scale = torch.clamp(
torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
)
# Equation 6
safety_concept_scale = torch.where(
(noise_pred_text - noise_pred_safety_concept) >= sld_threshold,
torch.zeros_like(scale),
scale,
)
# Equation 4
noise_guidance_safety = torch.mul(
(noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
)
# Equation 7
noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
# Equation 8
safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
if i >= sld_warmup_steps: # Warmup
# Equation 3
noise_guidance = noise_guidance - noise_guidance_safety
noise_pred = noise_pred_uncond + guidance_scale * noise_guidance
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept, flagged_images = self.run_safety_checker(
image, device, text_embeddings.dtype, enable_safety_guidance
)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if flagged_images is not None:
flagged_images = self.numpy_to_pil(flagged_images)
if not return_dict:
return (
image,
has_nsfw_concept,
self._safety_text_concept if enable_safety_guidance else None,
flagged_images,
)
return StableDiffusionSafePipelineOutput(
images=image,
nsfw_content_detected=has_nsfw_concept,
applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None,
unsafe_images=flagged_images,
)
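
The SLD guidance arithmetic in the denoising loop above can be followed more easily on plain numbers. The following is a minimal scalar sketch of one update, assuming the tensor operations are purely element-wise (as they are in the loop). The function name `sld_guidance_step` and the argument names `eps_uncond`, `eps_text` and `eps_safety` are illustrative and not part of the pipeline's API; the default values are copied from the `__call__` signature.

```python
def sld_guidance_step(
    eps_uncond,
    eps_text,
    eps_safety,
    momentum,
    step,
    guidance_scale=7.5,
    sld_guidance_scale=1000.0,
    sld_warmup_steps=10,
    sld_threshold=0.01,
    sld_momentum_scale=0.3,
    sld_mom_beta=0.4,
):
    """Scalar stand-in for one SLD-guided update; returns (noise_pred, momentum)."""
    # default classifier-free guidance direction
    guidance = eps_text - eps_uncond

    # clipped guidance scale, zeroed when the prompt estimate already lies
    # at least `sld_threshold` above the safety-concept estimate
    diff = eps_text - eps_safety
    scale = min(abs(diff) * sld_guidance_scale, 1.0)
    if diff >= sld_threshold:
        scale = 0.0

    # safety guidance term plus the accumulated momentum
    safety_guidance = (eps_safety - eps_uncond) * scale
    safety_guidance = safety_guidance + sld_momentum_scale * momentum

    # momentum is updated on every step, including warmup steps
    momentum = sld_mom_beta * momentum + (1 - sld_mom_beta) * safety_guidance

    # safety guidance only changes the prediction once warmup is over
    if step >= sld_warmup_steps:
        guidance = guidance - safety_guidance

    return eps_uncond + guidance_scale * guidance, momentum
```

Calling the function with the same inputs at `step=0` and at `step=10` shows the effect of the warmup: the momentum builds up in both cases, but only the later call subtracts the safety term from the guidance.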

View File

@@ -0,0 +1,110 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
from ...utils import logging
logger = logging.get_logger(__name__)
def cosine_distance(image_embeds, text_embeds):
normalized_image_embeds = nn.functional.normalize(image_embeds)
normalized_text_embeds = nn.functional.normalize(text_embeds)
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
class SafeStableDiffusionSafetyChecker(PreTrainedModel):
config_class = CLIPConfig
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPConfig):
super().__init__(config)
self.vision_model = CLIPVisionModel(config.vision_config)
self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
@torch.no_grad()
def forward(self, clip_input, images):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
result = []
batch_size = image_embeds.shape[0]
for i in range(batch_size):
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
# increase this value to create a stronger `nsfw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
for concept_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concept_idx]
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concept_idx] > 0:
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
adjustment = 0.01
for concept_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concept_idx]
concept_threshold = self.concept_embeds_weights[concept_idx].item()
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concept_idx] > 0:
result_img["bad_concepts"].append(concept_idx)
result.append(result_img)
has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
return images, has_nsfw_concepts
@torch.no_grad()
def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
cos_dist = cosine_distance(image_embeds, self.concept_embeds)
# increase this value to create a stronger `nsfw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
# special_scores = special_scores.round(decimals=3)
special_care = torch.any(special_scores > 0, dim=1)
special_adjustment = special_care * 0.01
special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
# concept_scores = concept_scores.round(decimals=3)
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
return images, has_nsfw_concepts
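
The decision rule used by the safety checker above can be sketched without torch. This is a simplified illustration, assuming one image embedding and plain Python lists. It follows the unrounded `forward_onnx` variant, and the helper names `cosine_similarity` and `is_flagged` are made up for this example. Note that the `cosine_distance` helper in the file actually returns a cosine similarity (a dot product of normalized vectors), which is what the thresholds are compared against.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def is_flagged(image_embed, concept_embeds, concept_thresholds, special_embeds, special_thresholds):
    """Return True if any concept score exceeds its threshold."""
    # a hit on any special-care concept tightens the filter by 0.01
    special_hit = any(
        cosine_similarity(image_embed, embed) - threshold > 0
        for embed, threshold in zip(special_embeds, special_thresholds)
    )
    adjustment = 0.01 if special_hit else 0.0

    # the image is flagged when at least one adjusted concept score is positive
    return any(
        cosine_similarity(image_embed, embed) - threshold + adjustment > 0
        for embed, threshold in zip(concept_embeds, concept_thresholds)
    )
```

With a concept score just below zero, the result depends on whether a special-care concept fires, which is the only coupling between the two groups of embeddings.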

View File

@@ -0,0 +1,16 @@
from ...utils import is_torch_available, is_transformers_available, is_transformers_version
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
from .modeling_text_unet import UNetFlatConditionModel
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
else:
from ...utils.dummy_torch_and_transformers_objects import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
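
The version-gated import above follows a common pattern: import the real classes only when the backends are present and recent enough, otherwise expose placeholders that fail on use. Below is a rough standard-library sketch of that pattern. The helpers `release_tuple` and `package_at_least` and the class `DummyPipeline` are hypothetical names for this example, not the `diffusers` utilities, and the version comparison is deliberately simplified: it looks only at the leading numeric components and ignores pre-release ordering.

```python
import importlib.metadata
import importlib.util


def release_tuple(version):
    """Leading numeric components of a version string, e.g. '4.25.0.dev0' -> (4, 25, 0)."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def package_at_least(name, minimum):
    """True if top-level package `name` is installed with a release >= `minimum`."""
    if importlib.util.find_spec(name) is None:
        return False
    try:
        installed = importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        # importable, but no distribution metadata to read a version from
        return False
    return release_tuple(installed) >= release_tuple(minimum)


class DummyPipeline:
    """Placeholder that raises only when someone tries to instantiate it."""

    def __init__(self, *args, **kwargs):
        raise ImportError("This pipeline needs torch and transformers>=4.25.0.dev0.")


# find_spec only locates the package; it does not import it
BACKENDS_AVAILABLE = importlib.util.find_spec("torch") is not None and package_at_least(
    "transformers", "4.25.0.dev0"
)
```

Deferring the error to instantiation keeps the package importable when optional dependencies are missing, so unrelated pipelines remain usable.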

File diff suppressed because it is too large

View File

@@ -0,0 +1,463 @@
import inspect
from typing import Callable, List, Optional, Union
import torch
import PIL.Image
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class VersatileDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image, image-variation and dual-guided generation using Versatile Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
image_feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from input images to be used as inputs for the `image_encoder`.
image_encoder ([`CLIPVisionModel`]):
Vision model that encodes image prompts.
"""
tokenizer: CLIPTokenizer
image_feature_extractor: CLIPFeatureExtractor
text_encoder: CLIPTextModel
image_encoder: CLIPVisionModel
image_unet: UNet2DConditionModel
text_unet: UNet2DConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
def __init__(
self,
tokenizer: CLIPTokenizer,
image_feature_extractor: CLIPFeatureExtractor,
text_encoder: CLIPTextModel,
image_encoder: CLIPVisionModel,
image_unet: UNet2DConditionModel,
text_unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
image_feature_extractor=image_feature_extractor,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_unet=image_unet,
text_unet=text_unet,
vae=vae,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.image_unet.config.attention_head_dim // 2
self.image_unet.set_attention_slice(slice_size)
self.text_unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
@torch.no_grad()
def image_variation(
self,
image: Union[torch.FloatTensor, PIL.Image.Image],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
The image prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the input `image`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts that the image generation should steer away from. Ignored when not using guidance
(i.e., ignored if `guidance_scale` is not greater than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
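As a rough sketch of what `guidance_scale` does, classifier-free guidance combines an unconditional and a conditional noise prediction. The helper `apply_guidance` and the plain lists standing in for tensors are assumptions made for illustration; they are not this pipeline's code.

```python
from typing import List


def apply_guidance(noise_uncond: List[float], noise_cond: List[float], guidance_scale: float) -> List[float]:
    # Hypothetical helper: move from the unconditional prediction towards the
    # conditional one, scaled by guidance_scale (the weight `w`).
    return [u + guidance_scale * (c - u) for u, c in zip(noise_uncond, noise_cond)]


uncond = [0.0, 1.0]  # made-up values standing in for noise predictions
cond = [1.0, 3.0]
print(apply_guidance(uncond, cond, 1.0))  # [1.0, 3.0]
print(apply_guidance(uncond, cond, 7.5))  # [7.5, 16.0]
```

With a scale of `1` the result equals the conditional prediction, which is why guidance only takes effect for values greater than `1`.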
Examples:
```py
>>> from diffusers import VersatileDiffusionPipeline
>>> import torch
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image
>>> # let's download an initial image
>>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
>>> response = requests.get(url)
>>> image = Image.open(BytesIO(response.content)).convert("RGB")
>>> pipe = VersatileDiffusionPipeline.from_pretrained(
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> generator = torch.Generator(device="cuda").manual_seed(0)
>>> image = pipe.image_variation(image, generator=generator).images[0]
>>> image.save("./car_variation.png")
```
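When `height` and `width` are left as `None`, the documented default is `sample_size * vae_scale_factor`, where the scale factor is derived from the number of VAE blocks. A minimal sketch of that arithmetic follows; the config values are assumed for illustration and are not read from any checkpoint, and `default_image_size` is a hypothetical helper.

```python
from typing import List


def default_image_size(sample_size: int, block_out_channels: List[int]) -> int:
    # Same formula as in the constructor: 2 ** (number of VAE blocks - 1).
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)
    return sample_size * vae_scale_factor


# Assumed config values, for illustration only.
size = default_image_size(sample_size=64, block_out_channels=[128, 256, 512, 512])
print(size)  # 512
```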
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys()
components = {name: component for name, component in self.components.items() if name in expected_components}
return VersatileDiffusionImageVariationPipeline(**components)(
image=image,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
)
@torch.no_grad()
def text_to_image(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts that the image generation should steer away from. Ignored when not using guidance
(i.e., ignored if `guidance_scale` is not greater than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
```py
>>> from diffusers import VersatileDiffusionPipeline
>>> import torch
>>> pipe = VersatileDiffusionPipeline.from_pretrained(
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> generator = torch.Generator(device="cuda").manual_seed(0)
>>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0]
>>> image.save("./astronaut.png")
```
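The `callback` and `callback_steps` arguments can be pictured with the loop below. It is a sketch under the assumption that the callback fires whenever the step index is a multiple of `callback_steps`; `run_steps`, the timestep values, and the float standing in for the latents tensor are all hypothetical.

```python
from typing import Callable, List, Optional, Tuple


def run_steps(
    timesteps: List[int],
    callback: Optional[Callable[[int, int, float], None]] = None,
    callback_steps: int = 1,
) -> float:
    latents = 0.0  # stand-in for the latents tensor
    for i, t in enumerate(timesteps):
        latents += 1.0  # placeholder for one denoising update
        # Assumed firing rule: every `callback_steps` steps, starting at step 0.
        if callback is not None and i % callback_steps == 0:
            callback(i, t, latents)
    return latents


seen: List[Tuple[int, int]] = []
run_steps(
    [999, 749, 499, 249],
    callback=lambda step, timestep, latents: seen.append((step, timestep)),
    callback_steps=2,
)
print(seen)  # [(0, 999), (2, 499)]
```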
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys()
components = {name: component for name, component in self.components.items() if name in expected_components}
temp_pipeline = VersatileDiffusionTextToImagePipeline(**components)
output = temp_pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
)
# swap the attention blocks back to the original state
temp_pipeline._swap_unet_attention_blocks()
return output
@torch.no_grad()
def dual_guided(
self,
prompt: Union[str, List[str]],
image: Union[PIL.Image.Image, List[PIL.Image.Image]],
text_to_image_strength: float = 0.5,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
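prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
image (`PIL.Image.Image` or `List[PIL.Image.Image]`):
The image prompt or prompts to guide the image generation.
text_to_image_strength (`float`, *optional*, defaults to 0.5):
Balances the influence of the text `prompt` against the `image` prompt on the generated image.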
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
```py
>>> from diffusers import VersatileDiffusionPipeline
>>> import torch
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image
>>> # let's download an initial image
>>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
>>> response = requests.get(url)
>>> image = Image.open(BytesIO(response.content)).convert("RGB")
>>> text = "a red car in the sun"
>>> pipe = VersatileDiffusionPipeline.from_pretrained(
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> generator = torch.Generator(device="cuda").manual_seed(0)
>>> text_to_image_strength = 0.75
>>> image = pipe.dual_guided(
... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
... ).images[0]
>>> image.save("./car_variation.png")
```
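This method builds a temporary task-specific pipeline from only those registered components that the task pipeline's `__init__` accepts. A minimal sketch of that filtering step is shown here; `TaskPipeline` and the string placeholders are hypothetical stand-ins, not the real pipeline classes or modules.

```python
import inspect


class TaskPipeline:
    # Hypothetical stand-in for a task-specific pipeline class.
    def __init__(self, vae, scheduler):
        self.vae = vae
        self.scheduler = scheduler


# Placeholder components; a real pipeline would hold models here.
all_components = {"vae": "vae-module", "scheduler": "scheduler-module", "text_unet": "unused-here"}

# Keep only the components named in the task pipeline's constructor signature.
expected = inspect.signature(TaskPipeline.__init__).parameters.keys()
kept = {name: component for name, component in all_components.items() if name in expected}

task = TaskPipeline(**kept)
print(sorted(kept))  # ['scheduler', 'vae']
```

Filtering by signature means components the task does not need are never passed along, so the constructor does not fail on unexpected keyword arguments.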
Returns:
[`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated images.
"""
expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys()
components = {name: component for name, component in self.components.items() if name in expected_components}
temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components)
output = temp_pipeline(
prompt=prompt,
image=image,
text_to_image_strength=text_to_image_strength,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
)
temp_pipeline._revert_dual_attention()
return output
