diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index 542920d7f6..d51623e735 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -9,11 +9,8 @@ concurrency: jobs: build: - uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@use_hf_hub + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main with: commit_sha: ${{ github.event.pull_request.head.sha }} pr_number: ${{ github.event.number }} package: diffusers - secrets: - token: ${{ secrets.HF_DOC_PUSH }} - comment_bot_token: ${{ secrets.HUGGINGFACE_PUSH }} diff --git a/.github/workflows/delete_doc_comment.yml b/.github/workflows/delete_doc_comment.yml index e1b2da9567..238dc0bdba 100644 --- a/.github/workflows/delete_doc_comment.yml +++ b/.github/workflows/delete_doc_comment.yml @@ -7,10 +7,7 @@ on: jobs: delete: - uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@use_hf_hub + uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main with: pr_number: ${{ github.event.number }} package: diffusers - secrets: - token: ${{ secrets.HF_DOC_PUSH }} - comment_bot_token: ${{ secrets.HUGGINGFACE_PUSH }} diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index dc1c482aa0..55a9bd68de 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -60,6 +60,7 @@ jobs: run: | python -m pip install -e .[quality,test] python -m pip install git+https://github.com/huggingface/accelerate + python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment run: | @@ -127,6 +128,7 @@ jobs: ${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate + ${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment shell: arch -arch arm64 bash {0} diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 2beb05e8ea..4bab00b7ee 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -62,6 +62,7 @@ jobs: run: | python -m pip install -e .[quality,test] python -m pip install git+https://github.com/huggingface/accelerate + python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment run: | @@ -131,6 +132,7 @@ jobs: run: | python -m pip install -e .[quality,test,training] python -m pip install git+https://github.com/huggingface/accelerate + python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment run: | diff --git a/README.md b/README.md index 64cbd15aab..ff523d060c 100644 --- a/README.md +++ b/README.md @@ -152,15 +152,7 @@ it before the pipeline and pass it to `from_pretrained`. ```python from diffusers import LMSDiscreteScheduler -lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") - -pipe = StableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - revision="fp16", - torch_dtype=torch.float16, - scheduler=lms, -) -pipe = pipe.to("cuda") +pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) prompt = "a photo of an astronaut riding a horse on mars" image = pipe(prompt).images[0] @@ -353,7 +345,8 @@ Textual Inversion is a technique for capturing novel concepts from a small numbe ## Stable Diffusion Community Pipelines -The release of Stable Diffusion as an open source model has fostered a lot of interesting ideas and experimentation. Our [Community Examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community) contains many ideas worth exploring, like interpolating to create animated videos, using CLIP Guidance for additional prompt fidelity, term weighting, and much more! [Take a look](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) and [contribute your own](https://huggingface.co/docs/diffusers/using-diffusers/contribute_pipeline). +The release of Stable Diffusion as an open source model has fostered a lot of interesting ideas and experimentation. +Our [Community Examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community) contains many ideas worth exploring, like interpolating to create animated videos, using CLIP Guidance for additional prompt fidelity, term weighting, and much more! [Take a look](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) and [contribute your own](https://huggingface.co/docs/diffusers/using-diffusers/contribute_pipeline). ## Other Examples @@ -402,10 +395,14 @@ image.save("ddpm_generated_image.png") - [Unconditional Latent Diffusion](https://huggingface.co/CompVis/ldm-celebahq-256) - [Unconditional Diffusion with continuous scheduler](https://huggingface.co/google/ncsnpp-ffhq-1024) -**Other Notebooks**: +**Other Image Notebooks**: * [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg), * [tweak images via repeated Stable Diffusion seeds](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg), +**Diffusers for Other Modalities**: +* [Molecule conformation generation](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg), +* [Model-based reinforcement learning](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg), + ### Web Demos If you just want to play around with some web demos, you can try out the following 🚀 Spaces: | Model | Hugging Face Spaces | diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index d8efb5eee3..9571444883 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -10,6 +10,8 @@ - sections: - local: using-diffusers/loading title: "Loading Pipelines, Models, and Schedulers" + - local: using-diffusers/schedulers + title: "Using different Schedulers" - local: using-diffusers/configuration title: "Configuring Pipelines, Models, and Schedulers" - local: using-diffusers/custom_pipeline_overview @@ -29,6 +31,14 @@ - local: using-diffusers/contribute_pipeline title: "How to contribute a Pipeline" title: "Pipelines for Inference" + - sections: + - local: using-diffusers/rl + title: "Reinforcement Learning" + - local: using-diffusers/audio + title: "Audio" + - local: using-diffusers/other-modalities + title: "Other Modalities" + title: "Taking Diffusers Beyond Images" title: "Using Diffusers" - sections: - local: optimization/fp16 @@ -78,6 +88,8 @@ - sections: - local: api/pipelines/overview title: "Overview" + - local: api/pipelines/alt_diffusion + title: "AltDiffusion" - local: api/pipelines/cycle_diffusion title: "Cycle Diffusion" - local: api/pipelines/ddim @@ -94,13 +106,23 @@ title: "Score SDE VE" - local: api/pipelines/stable_diffusion title: "Stable Diffusion" + - local: api/pipelines/stable_diffusion_2 + title: "Stable Diffusion 2" + - local: api/pipelines/stable_diffusion_safe + title: "Safe Stable Diffusion" - local: api/pipelines/stochastic_karras_ve title: "Stochastic Karras VE" - local: api/pipelines/dance_diffusion title: "Dance Diffusion" + - local: api/pipelines/versatile_diffusion + title: "Versatile Diffusion" - local: api/pipelines/vq_diffusion title: "VQ Diffusion" - local: api/pipelines/repaint title: "RePaint" title: "Pipelines" + - sections: + - local: api/experimental/rl + title: "RL Planning" + title: "Experimental Features" title: "API" diff --git a/docs/source/api/configuration.mdx b/docs/source/api/configuration.mdx index 45176f55b0..423c31f462 100644 --- a/docs/source/api/configuration.mdx +++ b/docs/source/api/configuration.mdx @@ -15,9 +15,9 @@ specific language governing permissions and limitations under the License. In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are passed to the respective `__init__` methods in a JSON-configuration file. -TODO(PVP) - add example and better info here - ## ConfigMixin + [[autodoc]] ConfigMixin + - load_config - from_config - save_config diff --git a/docs/source/api/experimental/rl.mdx b/docs/source/api/experimental/rl.mdx new file mode 100644 index 0000000000..65abb06e75 --- /dev/null +++ b/docs/source/api/experimental/rl.mdx @@ -0,0 +1,15 @@ + + +# TODO + +Coming soon! \ No newline at end of file diff --git a/docs/source/api/pipelines/alt_diffusion.mdx b/docs/source/api/pipelines/alt_diffusion.mdx new file mode 100644 index 0000000000..8d7d795d76 --- /dev/null +++ b/docs/source/api/pipelines/alt_diffusion.mdx @@ -0,0 +1,83 @@ + + +# AltDiffusion + +AltDiffusion was proposed in [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) by Zhongzhi Chen, Guang Liu, Bo-Wen Zhang, Fulong Ye, Qinghong Yang, Ledell Wu + +The abstract of the paper is the following: + +*In this work, we present a conceptually simple and effective method to train a strong bilingual multimodal representation model. Starting from the pretrained multimodal representation model CLIP released by OpenAI, we switched its text encoder with a pretrained multilingual text encoder XLM-R, and aligned both languages and image representations by a two-stage training schema consisting of teacher learning and contrastive learning. We validate our method through evaluations of a wide range of tasks. We set new state-of-the-art performances on a bunch of tasks including ImageNet-CN, Flicker30k- CN, and COCO-CN. Further, we obtain very close performances with CLIP on almost all tasks, suggesting that one can simply alter the text encoder in CLIP for extended capabilities such as multilingual understanding.* + + +*Overview*: + +| Pipeline | Tasks | Colab | Demo +|---|---|:---:|:---:| +| [pipeline_alt_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py) | *Text-to-Image Generation* | - | - +| [pipeline_alt_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | - |- + +## Tips + +- AltDiffusion is conceptually exaclty the same as [Stable Diffusion](./api/pipelines/stable_diffusion). + +- *Run AltDiffusion* + +AltDiffusion can be tested very easily with the [`AltDiffusionPipeline`], [`AltDiffusionImg2ImgPipeline`] and the `"BAAI/AltDiffusion-m9"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation) and the [Image-to-Image Generation Guide](./using-diffusers/img2img). + +- *How to load and use different schedulers.* + +The alt diffusion pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the alt diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: + +```python +>>> from diffusers import AltDiffusionPipeline, EulerDiscreteScheduler + +>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("BAAI/AltDiffusion-m9", subfolder="scheduler") +>>> pipeline = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", scheduler=euler_scheduler) +``` + + +- *How to convert all use cases with multiple or single pipeline* + +If you want to use all possible use cases in a single `DiffusionPipeline` we recommend using the `components` functionality to instantiate all components in the most memory-efficient way: + +```python +>>> from diffusers import ( +... AltDiffusionPipeline, +... AltDiffusionImg2ImgPipeline, +... ) + +>>> text2img = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9") +>>> img2img = AltDiffusionImg2ImgPipeline(**text2img.components) + +>>> # now you can use text2img(...) and img2img(...) just like the call methods of each respective pipeline +``` + +## AltDiffusionPipelineOutput +[[autodoc]] pipelines.alt_diffusion.AltDiffusionPipelineOutput + +## AltDiffusionPipeline +[[autodoc]] AltDiffusionPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + +## AltDiffusionImg2ImgPipeline +[[autodoc]] AltDiffusionImg2ImgPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing diff --git a/docs/source/api/pipelines/cycle_diffusion.mdx b/docs/source/api/pipelines/cycle_diffusion.mdx index 50d2a5c87e..8eecd3d624 100644 --- a/docs/source/api/pipelines/cycle_diffusion.mdx +++ b/docs/source/api/pipelines/cycle_diffusion.mdx @@ -39,7 +39,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler # load the pipeline # make sure you're logged in with `huggingface-cli login` model_id_or_path = "CompVis/stable-diffusion-v1-4" -scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler") +scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") # let's download an initial image diff --git a/docs/source/api/pipelines/latent_diffusion.mdx b/docs/source/api/pipelines/latent_diffusion.mdx index 4ade13e67b..370d014f5a 100644 --- a/docs/source/api/pipelines/latent_diffusion.mdx +++ b/docs/source/api/pipelines/latent_diffusion.mdx @@ -39,9 +39,9 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff ## LDMTextToImagePipeline -[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline +[[autodoc]] LDMTextToImagePipeline - __call__ ## LDMSuperResolutionPipeline -[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution.LDMSuperResolutionPipeline +[[autodoc]] LDMSuperResolutionPipeline - __call__ diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx index d68961a2fc..eed8e0d0b0 100644 --- a/docs/source/api/pipelines/overview.mdx +++ b/docs/source/api/pipelines/overview.mdx @@ -44,11 +44,13 @@ available a colab notebook to directly try them out. | Pipeline | Paper | Tasks | Colab |---|---|:---:|:---:| +| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | - | [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | | [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | | [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | +| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | | [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | | [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | @@ -56,7 +58,14 @@ available a colab notebook to directly try them out. | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) -| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image | +| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) +| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | | [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | diff --git a/docs/source/api/pipelines/repaint.mdx b/docs/source/api/pipelines/repaint.mdx index 0b7de8a457..ce262daffa 100644 --- a/docs/source/api/pipelines/repaint.mdx +++ b/docs/source/api/pipelines/repaint.mdx @@ -54,7 +54,7 @@ original_image = download_image(img_url).resize((256, 256)) mask_image = download_image(mask_url).resize((256, 256)) # Load the RePaint scheduler and pipeline based on a pretrained DDPM model -scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256") +scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256") pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler) pipe = pipe.to("cuda") diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx index 26d6a210ad..70c4abaaf6 100644 --- a/docs/source/api/pipelines/stable_diffusion.mdx +++ b/docs/source/api/pipelines/stable_diffusion.mdx @@ -34,17 +34,21 @@ For more details about how Stable Diffusion works and how it differs from the ba ### How to load and use different schedulers. The stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. -To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: ```python -from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler +>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler -euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") -pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler) +>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") +>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler) ``` -### How to conver all use cases with multiple or single pipeline +### How to convert all use cases with multiple or single pipeline If you want to use all possible use cases in a single `DiffusionPipeline` you can either: - Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or @@ -57,11 +61,11 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca ... StableDiffusionInpaintPipeline, ... ) ->>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") ->>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components) ->>> inpaint = StableDiffusionInpaintPipeline(**img2text.components) +>>> text2img = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +>>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components) +>>> inpaint = StableDiffusionInpaintPipeline(**text2img.components) ->>> # now you can use img2text(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline +>>> # now you can use text2img(...), img2img(...), inpaint(...) just like the call methods of each respective pipeline ``` ## StableDiffusionPipelineOutput @@ -72,6 +76,8 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca - __call__ - enable_attention_slicing - disable_attention_slicing + - enable_vae_slicing + - disable_vae_slicing ## StableDiffusionImg2ImgPipeline [[autodoc]] StableDiffusionImg2ImgPipeline @@ -84,3 +90,17 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca - __call__ - enable_attention_slicing - disable_attention_slicing + + +## StableDiffusionImageVariationPipeline +[[autodoc]] StableDiffusionImageVariationPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + + +## StableDiffusionUpscalePipeline +[[autodoc]] StableDiffusionUpscalePipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing diff --git a/docs/source/api/pipelines/stable_diffusion_2.mdx b/docs/source/api/pipelines/stable_diffusion_2.mdx new file mode 100644 index 0000000000..5df9195034 --- /dev/null +++ b/docs/source/api/pipelines/stable_diffusion_2.mdx @@ -0,0 +1,142 @@ + + +# Stable diffusion 2 + +Stable Diffusion 2 is a text-to-image _latent diffusion_ model built upon the work of [Stable Diffusion 1](https://stability.ai/blog/stable-diffusion-public-release). +The project to train Stable Diffusion 2 was led by Robin Rombach and Katherine Crowson from [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). + +*The Stable Diffusion 2.0 release includes robust text-to-image models trained using a brand new text encoder (OpenCLIP), developed by LAION with support from Stability AI, which greatly improves the quality of the generated images compared to earlier V1 releases. The text-to-image models in this release can generate images with default resolutions of both 512x512 pixels and 768x768 pixels. +These models are trained on an aesthetic subset of the [LAION-5B dataset](https://laion.ai/blog/laion-5b/) created by the DeepFloyd team at Stability AI, which is then further filtered to remove adult content using [LAION’s NSFW filter](https://openreview.net/forum?id=M3Y74vmsMcY).* + +For more details about how Stable Diffusion 2 works and how it differs from Stable Diffusion 1, please refer to the official [launch announcement post](https://stability.ai/blog/stable-diffusion-v2-release). + +## Tips + +### Available checkpoints: + +Note that the architecture is more or less identical to [Stable Diffusion 1](./api/pipelines/stable_diffusion) so please refer to [this page](./api/pipelines/stable_diffusion) for API documentation. + +- *Text-to-Image (512x512 resolution)*: [stabilityai/stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) with [`StableDiffusionPipeline`] +- *Text-to-Image (768x768 resolution)*: [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) with [`StableDiffusionPipeline`] +- *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`] +- *Image Upscaling (x4 resolution resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) [`StableDiffusionUpscalePipeline`] + +We recommend using the [`DPMSolverMultistepScheduler`] as it's currently the fastest scheduler there is. + +- *Text-to-Image (512x512 resolution)*: + +```python +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler +import torch + +repo_id = "stabilityai/stable-diffusion-2-base" +pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16") + +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +prompt = "High quality photo of an astronaut riding a horse in space" +image = pipe(prompt, num_inference_steps=25).images[0] +image.save("astronaut.png") +``` + +- *Text-to-Image (768x768 resolution)*: + +```python +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler +import torch + +repo_id = "stabilityai/stable-diffusion-2" +pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16") + +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +prompt = "High quality photo of an astronaut riding a horse in space" +image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0] +image.save("astronaut.png") +``` + +- *Image Inpainting (512x512 resolution)*: + +```python +import PIL +import requests +import torch +from io import BytesIO + +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler + + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + +init_image = download_image(img_url).resize((512, 512)) +mask_image = download_image(mask_url).resize((512, 512)) + +repo_id = "stabilityai/stable-diffusion-2-inpainting" +pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16") + +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +prompt = "Face of a yellow cat, high resolution, sitting on a park bench" +image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=25).images[0] + +image.save("yellow_cat.png") +``` + +- *Image Upscaling (x4 resolution resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) [`StableDiffusionUpscalePipeline`] + +```python +import requests +from PIL import Image +from io import BytesIO +from diffusers import StableDiffusionUpscalePipeline +import torch + +# load model and scheduler +model_id = "stabilityai/stable-diffusion-x4-upscaler" +pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) +pipeline = pipeline.to("cuda") + +# let's download an image +url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" +response = requests.get(url) +low_res_img = Image.open(BytesIO(response.content)).convert("RGB") +low_res_img = low_res_img.resize((128, 128)) +prompt = "a white cat" +upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] +upscaled_image.save("upsampled_cat.png") +``` + +### How to load and use different schedulers. + +The stable diffusion pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: + +```python +>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler + +>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler") +>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=euler_scheduler) +``` diff --git a/docs/source/api/pipelines/stable_diffusion_safe.mdx b/docs/source/api/pipelines/stable_diffusion_safe.mdx new file mode 100644 index 0000000000..81fc59d392 --- /dev/null +++ b/docs/source/api/pipelines/stable_diffusion_safe.mdx @@ -0,0 +1,90 @@ + + +# Safe Stable Diffusion + +Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://arxiv.org/abs/2211.05105) and mitigates the well known issue that models like Stable Diffusion that are trained on unfiltered, web-crawled datasets tend to suffer from inappropriate degeneration. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, or otherwise offensive content. +Safe Stable Diffusion is an extension to the Stable Diffusion that drastically reduces content like this. + +The abstract of the paper is the following: + +*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.* + + +*Overview*: + +| Pipeline | Tasks | Colab | Demo +|---|---|:---:|:---:| +| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | - + +## Tips + +- Safe Stable Diffusion may also be used with weights of [Stable Diffusion](./api/pipelines/stable_diffusion). + +### Run Safe Stable Diffusion + +Safe Stable Diffusion can be tested very easily with the [`StableDiffusionPipelineSafe`], and the `"AIML-TUDA/stable-diffusion-safe"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation). + +### Interacting with the Safety Concept + +To check and edit the currently used safety concept, use the `safety_concept` property of [`StableDiffusionPipelineSafe`] +```python +>>> from diffusers import StableDiffusionPipelineSafe + +>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe") +>>> pipeline.safety_concept +``` +For each image generation the active concept is also contained in [`StableDiffusionSafePipelineOutput`]. + +### Using pre-defined safety configurations + +You may use the 4 configurations defined in the [Safe Latent Diffusion paper](https://arxiv.org/abs/2211.05105) as follows: + +```python +>>> from diffusers import StableDiffusionPipelineSafe +>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig + +>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe") +>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker" +>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX) +``` + +The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONg`, and `SafetyConfig.MAX`. + +### How to load and use different schedulers. + +The safe stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: + +```python +>>> from diffusers import StableDiffusionPipelineSafe, EulerDiscreteScheduler + +>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("AIML-TUDA/stable-diffusion-safe", subfolder="scheduler") +>>> pipeline = StableDiffusionPipelineSafe.from_pretrained( +... "AIML-TUDA/stable-diffusion-safe", scheduler=euler_scheduler +... ) +``` + + +## StableDiffusionSafePipelineOutput +[[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput + +## StableDiffusionPipelineSafe +[[autodoc]] StableDiffusionPipelineSafe + - __call__ + - enable_attention_slicing + - disable_attention_slicing + diff --git a/docs/source/api/pipelines/versatile_diffusion.mdx b/docs/source/api/pipelines/versatile_diffusion.mdx new file mode 100644 index 0000000000..f557c5b0aa --- /dev/null +++ b/docs/source/api/pipelines/versatile_diffusion.mdx @@ -0,0 +1,73 @@ + + +# VersatileDiffusion + +VersatileDiffusion was proposed in [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) by Xingqian Xu, Zhangyang Wang, Eric Zhang, Kai Wang, Humphrey Shi . + +The abstract of the paper is the following: + +*The recent advances in diffusion models have set an impressive milestone in many generation tasks. Trending works such as DALL-E2, Imagen, and Stable Diffusion have attracted great interest in academia and industry. Despite the rapid landscape changes, recent new approaches focus on extensions and performance rather than capacity, thus requiring separate models for separate tasks. In this work, we expand the existing single-flow diffusion pipeline into a multi-flow network, dubbed Versatile Diffusion (VD), that handles text-to-image, image-to-text, image-variation, and text-variation in one unified model. Moreover, we generalize VD to a unified multi-flow multimodal diffusion framework with grouped layers, swappable streams, and other propositions that can process modalities beyond images and text. Through our experiments, we demonstrate that VD and its underlying framework have the following merits: a) VD handles all subtasks with competitive quality; b) VD initiates novel extensions and applications such as disentanglement of style and semantic, image-text dual-guided generation, etc.; c) Through these experiments and applications, VD provides more semantic insights of the generated outputs.* + +## Tips + +- VersatileDiffusion is conceptually very similar as [Stable Diffusion](./api/pipelines/stable_diffusion), but instead of providing just a image data stream conditioned on text, VersatileDiffusion provides both a image and text data stream and can be conditioned on both text and image. + +### *Run VersatileDiffusion* + +You can both load the memory intensive "all-in-one" [`VersatileDiffusionPipeline`] that can run all tasks +with the same class as shown in [`VersatileDiffusionPipeline.text_to_image`], [`VersatileDiffusionPipeline.image_variation`], and [`VersatileDiffusionPipeline.dual_guided`] + +**or** + +You can run the individual pipelines which are much more memory efficient: + +- *Text-to-Image*: [`VersatileDiffusionTextToImagePipeline.__call__`] +- *Image Variation*: [`VersatileDiffusionImageVariationPipeline.__call__`] +- *Dual Text and Image Guided Generation*: [`VersatileDiffusionDualGuidedPipeline.__call__`] + +### *How to load and use different schedulers.* + +The versatile diffusion pipelines uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the alt diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: + +```python +>>> from diffusers import VersatileDiffusionPipeline, EulerDiscreteScheduler + +>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("shi-labs/versatile-diffusion", subfolder="scheduler") +>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", scheduler=euler_scheduler) +``` + +## VersatileDiffusionPipeline +[[autodoc]] VersatileDiffusionPipeline + +## VersatileDiffusionTextToImagePipeline +[[autodoc]] VersatileDiffusionTextToImagePipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + +## VersatileDiffusionImageVariationPipeline +[[autodoc]] VersatileDiffusionImageVariationPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing + +## VersatileDiffusionDualGuidedPipeline +[[autodoc]] VersatileDiffusionDualGuidedPipeline + - __call__ + - enable_attention_slicing + - disable_attention_slicing diff --git a/docs/source/index.mdx b/docs/source/index.mdx index bae507ac11..975ff47b61 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -34,11 +34,13 @@ available a colab notebook to directly try them out. | Pipeline | Paper | Tasks | Colab |---|---|:---:|:---:| +| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | | [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | | [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | | [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | | [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | +| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | | [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | | [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | @@ -46,7 +48,14 @@ available a colab notebook to directly try them out. | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image | +| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | | [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index 4371daacc9..49fe3876bd 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -117,6 +117,34 @@ image = pipe(prompt).images[0] There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM! + +## Sliced VAE decode for larger batches + +To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time. + +You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use. + +To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example: + +```Python +import torch +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + revision="fp16", + torch_dtype=torch.float16, +) +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +pipe.enable_vae_slicing() +images = pipe([prompt] * 32).images +``` + +You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches. + + ## Offloading to CPU with accelerate for memory savings For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass. diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx index 463780a072..a50b476c3d 100644 --- a/docs/source/quicktour.mdx +++ b/docs/source/quicktour.mdx @@ -41,7 +41,7 @@ In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generat ```python >>> from diffusers import DiffusionPipeline ->>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") +>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") ``` The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. @@ -49,13 +49,13 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm You can move the generator object to GPU, just like you would in PyTorch. ```python ->>> generator.to("cuda") +>>> pipeline.to("cuda") ``` -Now you can use the `generator` on your text prompt: +Now you can use the `pipeline` on your text prompt: ```python ->>> image = generator("An image of a squirrel in Picasso style").images[0] +>>> image = pipeline("An image of a squirrel in Picasso style").images[0] ``` The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class). @@ -82,7 +82,7 @@ just like we did before only that now you need to pass your `AUTH_TOKEN`: ```python >>> from diffusers import DiffusionPipeline ->>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN) +>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN) ``` If you do not pass your authentication token you will see that the diffusion system will not be correctly @@ -102,7 +102,7 @@ token. Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned you can also load the pipeline as follows: ```python ->>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5") +>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5") ``` Running the pipeline is then identical to the code above as it's the same model architecture. @@ -115,19 +115,20 @@ Running the pipeline is then identical to the code above as it's the same model Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to -use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler, +use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`] scheduler, you could use it as follows: ```python ->>> from diffusers import LMSDiscreteScheduler +>>> from diffusers import EulerDiscreteScheduler ->>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") +>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN) ->>> generator = StableDiffusionPipeline.from_pretrained( -... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN -... ) +>>> # change scheduler to Euler +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) ``` +For more in-detail information on how to change between schedulers, please refer to the [Using Schedulers](./using-diffusers/schedulers) guide. + [Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model and can do much more than just generating images from text. We have dedicated a whole documentation page, just for Stable Diffusion [here](./conceptual/stable_diffusion). diff --git a/docs/source/using-diffusers/audio.mdx b/docs/source/using-diffusers/audio.mdx new file mode 100644 index 0000000000..5a5c2241ca --- /dev/null +++ b/docs/source/using-diffusers/audio.mdx @@ -0,0 +1,16 @@ + + +# Using Diffusers for audio + +The [`DanceDiffusionPipeline`] can be used to generate audio rapidly! +More coming soon! \ No newline at end of file diff --git a/docs/source/using-diffusers/conditional_image_generation.mdx b/docs/source/using-diffusers/conditional_image_generation.mdx index 6273a71d4c..5ed27ac917 100644 --- a/docs/source/using-diffusers/conditional_image_generation.mdx +++ b/docs/source/using-diffusers/conditional_image_generation.mdx @@ -44,5 +44,3 @@ You can save the image by simply calling: ```python >>> image.save("image_of_squirrel_painting.png") ``` - - diff --git a/docs/source/using-diffusers/loading.mdx b/docs/source/using-diffusers/loading.mdx index 2cb980ea61..c97ad5c5d0 100644 --- a/docs/source/using-diffusers/loading.mdx +++ b/docs/source/using-diffusers/loading.mdx @@ -19,7 +19,7 @@ In the following we explain in-detail how to easily load: - *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`] - *Diffusion Models* via [`ModelMixin.from_pretrained`] -- *Schedulers* via [`ConfigMixin.from_config`] +- *Schedulers* via [`SchedulerMixin.from_pretrained`] ## Loading pipelines @@ -137,15 +137,15 @@ from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultis repo_id = "runwayml/stable-diffusion-v1-5" -scheduler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler") +scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") # or -# scheduler = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler") +# scheduler = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler") stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler) ``` Three things are worth paying attention to here. -- First, the scheduler is loaded with [`ConfigMixin.from_config`] since it only depends on a configuration file and not any parameterized weights +- First, the scheduler is loaded with [`SchedulerMixin.from_pretrained`] - Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"` as the configuration of stable diffusion's scheduling is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler) - Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler` since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`] @@ -337,8 +337,8 @@ model = UNet2DModel.from_pretrained(repo_id) ## Loading schedulers -Schedulers cannot be loaded via a `from_pretrained` method, but instead rely on [`ConfigMixin.from_config`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file. -Therefore the loading method was given a different name here. +Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file. +For consistency, we use the same method name as we do for models or pipelines, but no weights are loaded in this case. In constrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers. For example, all of: @@ -367,13 +367,13 @@ from diffusers import ( repo_id = "runwayml/stable-diffusion-v1-5" -ddpm = DDPMScheduler.from_config(repo_id, subfolder="scheduler") -ddim = DDIMScheduler.from_config(repo_id, subfolder="scheduler") -pndm = PNDMScheduler.from_config(repo_id, subfolder="scheduler") -lms = LMSDiscreteScheduler.from_config(repo_id, subfolder="scheduler") -euler_anc = EulerAncestralDiscreteScheduler.from_config(repo_id, subfolder="scheduler") -euler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler") -dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler") +ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler") +ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler") +pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler") +lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler") # replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc` pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm) diff --git a/docs/source/using-diffusers/other-modalities.mdx b/docs/source/using-diffusers/other-modalities.mdx new file mode 100644 index 0000000000..1dc0877adb --- /dev/null +++ b/docs/source/using-diffusers/other-modalities.mdx @@ -0,0 +1,20 @@ + + +# Using Diffusers with other modalities + +Diffusers is in the process of expanding to modalities other than images. + +Currently, one example is for [molecule conformation](https://www.nature.com/subjects/molecular-conformation#:~:text=Definition,to%20changes%20in%20their%20environment.) generation. +* Generate conformations in Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) + +More coming soon! \ No newline at end of file diff --git a/docs/source/using-diffusers/rl.mdx b/docs/source/using-diffusers/rl.mdx new file mode 100644 index 0000000000..6e18e07001 --- /dev/null +++ b/docs/source/using-diffusers/rl.mdx @@ -0,0 +1,18 @@ + + +# Using Diffusers for reinforcement learning + +Support for one RL model and related pipelines is included in the `experimental` source of diffusers. + +To try some of this in colab, please look at the following example: +* Model-based reinforcement learning on Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg) diff --git a/docs/source/using-diffusers/schedulers.mdx b/docs/source/using-diffusers/schedulers.mdx new file mode 100644 index 0000000000..87ff789747 --- /dev/null +++ b/docs/source/using-diffusers/schedulers.mdx @@ -0,0 +1,262 @@ + + +# Schedulers + +Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize +a pipeline to one's use case. The best example of this are the [Schedulers](../api/schedulers.mdx). + +Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample, +schedulers define the whole denoising process, *i.e.*: +- How many denoising steps? +- Stochastic or deterministic? +- What algorithm to use to find the denoised sample + +They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**. +It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try out which works best. + +The following paragraphs shows how to do so with the 🧨 Diffusers library. + +## Load pipeline + +Let's start by loading the stable diffusion pipeline. +Remember that you have to be a registered user on the 🤗 Hugging Face Hub, and have "click-accepted" the [license](https://huggingface.co/runwayml/stable-diffusion-v1-5) in order to use stable diffusion. + +```python +from huggingface_hub import login +from diffusers import DiffusionPipeline +import torch + +# first we need to login with our access token +login() + +# Now we can download the pipeline +pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +``` + +Next, we move it to GPU: + +```python +pipeline.to("cuda") +``` + +## Access the scheduler + +The scheduler is always one of the components of the pipeline and is usually called `"scheduler"`. +So it can be accessed via the `"scheduler"` property. + +```python +pipeline.scheduler +``` + +**Output**: +``` +PNDMScheduler { + "_class_name": "PNDMScheduler", + "_diffusers_version": "0.8.0.dev0", + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": false, + "num_train_timesteps": 1000, + "set_alpha_to_one": false, + "skip_prk_steps": true, + "steps_offset": 1, + "trained_betas": null +} +``` + +We can see that the scheduler is of type [`PNDMScheduler`]. +Cool, now let's compare the scheduler in its performance to other schedulers. +First we define a prompt on which we will test all the different schedulers: + +```python +prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition." +``` + +Next, we create a generator from a random seed that will ensure that we can generate similar images as well as run the pipeline: + +```python +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator).images[0] +image +``` + +

+
+ +
+

+ + +## Changing the scheduler + +Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`SchedulerMixin.compatibles`] +which defines all compatible schedulers. You can take a look at all available, compatible schedulers for the Stable Diffusion pipeline as follows. + +```python +pipeline.scheduler.compatibles +``` + +**Output**: +``` +[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler, + diffusers.schedulers.scheduling_ddim.DDIMScheduler, + diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler, + diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler, + diffusers.schedulers.scheduling_pndm.PNDMScheduler, + diffusers.schedulers.scheduling_ddpm.DDPMScheduler, + diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler] +``` + +Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions: + +- [`LMSDiscreteScheduler`], +- [`DDIMScheduler`], +- [`DPMSolverMultistepScheduler`], +- [`EulerDiscreteScheduler`], +- [`PNDMScheduler`], +- [`DDPMScheduler`], +- [`EulerAncestralDiscreteScheduler`]. + +We will now compare the input prompt with all other schedulers. To change the scheduler of the pipeline you can make use of the +convenient [`ConfigMixin.config`] property in combination with the [`ConfigMixin.from_config`] function. + +```python +pipeline.scheduler.config +``` + +returns a dictionary of the configuration of the scheduler: + +**Output**: +``` +FrozenDict([('num_train_timesteps', 1000), + ('beta_start', 0.00085), + ('beta_end', 0.012), + ('beta_schedule', 'scaled_linear'), + ('trained_betas', None), + ('skip_prk_steps', True), + ('set_alpha_to_one', False), + ('steps_offset', 1), + ('_class_name', 'PNDMScheduler'), + ('_diffusers_version', '0.8.0.dev0'), + ('clip_sample', False)]) +``` + +This configuration can then be used to instantiate a scheduler +of a different class that is compatible with the pipeline. Here, +we change the scheduler to the [`DDIMScheduler`]. + +```python +from diffusers import DDIMScheduler + +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +``` + +Cool, now we can run the pipeline again to compare the generation quality. + +```python +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator).images[0] +image +``` + +

+
+ +
+

+ + +## Compare schedulers + +So far we have tried running the stable diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`]. +A number of better schedulers have been released that can be run with much fewer steps, let's compare them here: + +[`LMSDiscreteScheduler`] usually leads to better results: + +```python +from diffusers import LMSDiscreteScheduler + +pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator).images[0] +image +``` + +

+
+ +
+

+ + +[`EulerDiscreteScheduler`] and [`EulerAncestralDiscreteScheduler`] can generate high quality results with as little as 30 steps. + +```python +from diffusers import EulerDiscreteScheduler + +pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0] +image +``` + +

+
+ +
+

+ + +and: + +```python +from diffusers import EulerAncestralDiscreteScheduler + +pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0] +image +``` + +

+
+ +
+

+ + +At the time of writing this doc [`DPMSolverMultistepScheduler`] gives arguably the best speed/quality trade-off and can be run with as little +as 20 steps. + +```python +from diffusers import DPMSolverMultistepScheduler + +pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0] +image +``` + +

+
+ +
+

+ +As you can see most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different +schedulers to compare results. diff --git a/examples/community/README.md b/examples/community/README.md index fd6fff79c5..660f64098b 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -15,11 +15,14 @@ If a community doesn't work as expected, please open an issue and ping the autho | Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) | | Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech) | Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) | -| Composable Stable Diffusion| Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | +| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | | Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) | | Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | | Multilingual Stable Diffusion| Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) | | Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting| [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) | +| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting| [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) | +| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - |[Stuti R.](https://github.com/kingstut) | +| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | @@ -342,9 +345,10 @@ out = pipe( ) ``` - ### Composable Stable diffusion +[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models. + ```python import torch as th import numpy as np @@ -367,7 +371,7 @@ def dummy(images, **kwargs): pipe.safety_checker = dummy images = [] -generator = th.Generator("cuda").manual_seed(0) +generator = torch.Generator("cuda").manual_seed(0) seed = 0 prompt = "a forest | a camel" @@ -396,6 +400,7 @@ import requests from PIL import Image from io import BytesIO import torch +import os from diffusers import DiffusionPipeline, DDIMScheduler has_cuda = torch.cuda.is_available() device = torch.device('cpu' if not has_cuda else 'cuda') @@ -420,6 +425,7 @@ res = pipe.train( num_inference_steps=50, generator=generator) res = pipe(alpha=1) +os.makedirs("imagic", exist_ok=True) image = res.images[0] image.save('./imagic/imagic_image_alpha_1.png') res = pipe(alpha=1.5) @@ -596,7 +602,7 @@ For example, this could be used to place a logo on a shirt and make it blend sea import PIL import torch -from diffusers import StableDiffusionInpaintPipeline +from diffusers import DiffusionPipeline image_path = "./path-to-image.png" inner_image_path = "./path-to-inner-image.png" @@ -606,13 +612,120 @@ init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512)) inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512)) mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512)) -pipe = StableDiffusionInpaintPipeline.from_pretrained( +pipe = DiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-inpainting", + custom_pipeline="img2img_inpainting", revision="fp16", - torch_dtype=torch.float16, + torch_dtype=torch.float16 ) pipe = pipe.to("cuda") prompt = "Your prompt here!" image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0] ``` + +![2 by 2 grid demonstrating image to image inpainting.](https://user-images.githubusercontent.com/44398246/203506577-ec303be4-887e-4ebd-a773-c83fcb3dd01a.png) + +### Text Based Inpainting Stable Diffusion + +Use a text prompt to generate the mask for the area to be inpainted. +Currently uses the CLIPSeg model for mask generation, then calls the standard Stable Diffusion Inpainting pipeline to perform the inpainting. + +```python +from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation +from diffusers import DiffusionPipeline + +from PIL import Image +import requests +from torch import autocast + +processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") +model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") + +pipe = DiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", + custom_pipeline="text_inpainting", + segmentation_model=model, + segmentation_processor=processor +) +pipe = pipe.to("cuda") + + +url = "https://github.com/timojl/clipseg/blob/master/example_image.jpg?raw=true" +image = Image.open(requests.get(url, stream=True).raw).resize((512, 512)) +text = "a glass" # will mask out this text +prompt = "a cup" # the masked out region will be replaced with this + +with autocast("cuda"): + image = pipe(image=image, text=text, prompt=prompt).images[0] +``` + +### Bit Diffusion +Based https://arxiv.org/abs/2208.04202, this is used for diffusion on discrete data - eg, discreate image data, DNA sequence data. An unconditional discreate image can be generated like this: + +```python +from diffusers import DiffusionPipeline +pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="bit_diffusion") +image = pipe().images[0] + +``` + +### Stable Diffusion with K Diffusion + +Make sure you have @crowsonkb's https://github.com/crowsonkb/k-diffusion installed: + +``` +pip install k-diffusion +``` + +You can use the community pipeline as follows: + +```python +from diffusers import DiffusionPipeline + +pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion") +pipe = pipe.to("cuda") + +prompt = "an astronaut riding a horse on mars" +pipe.set_sampler("sample_heun") +generator = torch.Generator(device="cuda").manual_seed(seed) +image = pipe(prompt, generator=generator, num_inference_steps=20).images[0] + +image.save("./astronaut_heun_k_diffusion.png") +``` + +To make sure that K Diffusion and `diffusers` yield the same results: + +**Diffusers**: +```python +from diffusers import DiffusionPipeline, EulerDiscreteScheduler + +seed = 33 + +pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +generator = torch.Generator(device="cuda").manual_seed(seed) +image = pipe(prompt, generator=generator, num_inference_steps=50).images[0] +``` + +![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler.png) + +**K Diffusion**: +```python +from diffusers import DiffusionPipeline, EulerDiscreteScheduler + +seed = 33 + +pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion") +pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +pipe.set_sampler("sample_euler") +generator = torch.Generator(device="cuda").manual_seed(seed) +image = pipe(prompt, generator=generator, num_inference_steps=50).images[0] +``` + +![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler_k_diffusion.png) + diff --git a/examples/community/bit_diffusion.py b/examples/community/bit_diffusion.py new file mode 100644 index 0000000000..956e25a7e5 --- /dev/null +++ b/examples/community/bit_diffusion.py @@ -0,0 +1,265 @@ +from typing import Optional, Tuple, Union + +import torch + +from diffusers import DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.pipeline_utils import ImagePipelineOutput +from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput +from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput +from einops import rearrange, reduce + + +BITS = 8 + + +# convert to bit representations and back taken from https://github.com/lucidrains/bit-diffusion/blob/main/bit_diffusion/bit_diffusion.py +def decimal_to_bits(x, bits=BITS): + """expects image tensor ranging from 0 to 1, outputs bit tensor ranging from -1 to 1""" + device = x.device + + x = (x * 255).int().clamp(0, 255) + + mask = 2 ** torch.arange(bits - 1, -1, -1, device=device) + mask = rearrange(mask, "d -> d 1 1") + x = rearrange(x, "b c h w -> b c 1 h w") + + bits = ((x & mask) != 0).float() + bits = rearrange(bits, "b c d h w -> b (c d) h w") + bits = bits * 2 - 1 + return bits + + +def bits_to_decimal(x, bits=BITS): + """expects bits from -1 to 1, outputs image tensor from 0 to 1""" + device = x.device + + x = (x > 0).int() + mask = 2 ** torch.arange(bits - 1, -1, -1, device=device, dtype=torch.int32) + + mask = rearrange(mask, "d -> d 1 1") + x = rearrange(x, "b (c d) h w -> b c d h w", d=8) + dec = reduce(x * mask, "b c d h w -> b c h w", "sum") + return (dec / 255).clamp(0.0, 1.0) + + +# modified scheduler step functions for clamping the predicted x_0 between -bit_scale and +bit_scale +def ddim_bit_scheduler_step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = True, + generator=None, + return_dict: bool = True, +) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): TODO + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + scale = self.bit_scale + if self.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -scale, scale) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device = model_output.device if torch.is_tensor(model_output) else "cpu" + noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + +def ddpm_bit_scheduler_step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + prediction_type="epsilon", + generator=None, + return_dict: bool = True, +) -> Union[DDPMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples (`sample`). + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + t = timestep + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif prediction_type == "sample": + pred_original_sample = model_output + else: + raise ValueError(f"Unsupported prediction_type {prediction_type}.") + + # 3. Clip "predicted x_0" + scale = self.bit_scale + if self.config.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -scale, scale) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t + current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + noise = torch.randn( + model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator + ).to(model_output.device) + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) + + +class BitDiffusion(DiffusionPipeline): + def __init__( + self, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, DDPMScheduler], + bit_scale: Optional[float] = 1.0, + ): + super().__init__() + self.bit_scale = bit_scale + self.scheduler.step = ( + ddim_bit_scheduler_step if isinstance(scheduler, DDIMScheduler) else ddpm_bit_scheduler_step + ) + + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + height: Optional[int] = 256, + width: Optional[int] = 256, + num_inference_steps: Optional[int] = 50, + generator: Optional[torch.Generator] = None, + batch_size: Optional[int] = 1, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[Tuple, ImagePipelineOutput]: + latents = torch.randn( + (batch_size, self.unet.in_channels, height, width), + generator=generator, + ) + latents = decimal_to_bits(latents) * self.bit_scale + latents = latents.to(self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # predict the noise residual + noise_pred = self.unet(latents, t).sample + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + image = bits_to_decimal(latents) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index 14d9ee6322..7a319bddf0 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -78,7 +78,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): ) self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) - self.make_cutouts = MakeCutouts(feature_extractor.size) + cut_out_size = ( + feature_extractor.size + if isinstance(feature_extractor.size, int) + else feature_extractor.size["shortest_edge"] + ) + self.make_cutouts = MakeCutouts(cut_out_size) set_requires_grad(self.text_encoder, False) set_requires_grad(self.clip_model, False) diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py index 0c95fb4358..65966b4830 100644 --- a/examples/community/imagic_stable_diffusion.py +++ b/examples/community/imagic_stable_diffusion.py @@ -18,17 +18,38 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.utils import logging + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + logger = logging.get_logger(__name__) # pylint: disable=invalid-name def preprocess(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py index f7a107136d..3fa7db13a4 100644 --- a/examples/community/img2img_inpainting.py +++ b/examples/community/img2img_inpainting.py @@ -110,7 +110,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline): scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None: - logger.warn( + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py index 761aaeca69..4d7a73f5ba 100644 --- a/examples/community/interpolate_stable_diffusion.py +++ b/examples/community/interpolate_stable_diffusion.py @@ -101,7 +101,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline): scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None: - logger.warn( + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index e4ee7bf3c6..0e7dc9e1ed 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -13,9 +13,31 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.utils import deprecate, is_accelerate_available, logging + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name re_attention = re.compile( @@ -358,7 +380,7 @@ def get_weighted_text_embeddings( def preprocess_image(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) @@ -369,7 +391,7 @@ def preprocess_mask(mask): mask = mask.convert("L") w, h = mask.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) + mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"]) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? @@ -447,7 +469,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline): scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None: - logger.warn( + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py index 12e306a612..577772b9c3 100644 --- a/examples/community/lpw_stable_diffusion_onnx.py +++ b/examples/community/lpw_stable_diffusion_onnx.py @@ -11,9 +11,30 @@ from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from diffusers.utils import logging + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTokenizer +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + logger = logging.get_logger(__name__) # pylint: disable=invalid-name re_attention = re.compile( @@ -365,7 +386,7 @@ def get_weighted_text_embeddings( def preprocess_image(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) return 2.0 * image - 1.0 @@ -375,7 +396,7 @@ def preprocess_mask(mask): mask = mask.convert("L") w, h = mask.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) + mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"]) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py index c71c1f10c5..19974d6df0 100644 --- a/examples/community/multilingual_stable_diffusion.py +++ b/examples/community/multilingual_stable_diffusion.py @@ -113,7 +113,7 @@ class MultilingualStableDiffusion(DiffusionPipeline): scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None: - logger.warn( + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py new file mode 100755 index 0000000000..9592f7879f --- /dev/null +++ b/examples/community/sd_text2img_k_diffusion.py @@ -0,0 +1,479 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +from typing import Callable, List, Optional, Union + +import torch + +from diffusers import LMSDiscreteScheduler +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import is_accelerate_available, logging +from k_diffusion.external import CompVisDenoiser + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ModelWrapper: + def __init__(self, model, alphas_cumprod): + self.model = model + self.alphas_cumprod = alphas_cumprod + + def apply_model(self, *args, **kwargs): + return self.model(*args, **kwargs).sample + + +class StableDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + ): + super().__init__() + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + # get correct sigmas from LMS + scheduler = LMSDiscreteScheduler.from_config(scheduler.config) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + model = ModelWrapper(unet, scheduler.alphas_cumprod) + self.k_diffusion_model = CompVisDenoiser(model) + + def set_sampler(self, scheduler_type: str): + library = importlib.import_module("k_diffusion") + sampling = getattr(library, "sampling") + self.sampler = getattr(sampling, scheduler_type) + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // 8, width // 8) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = True + if guidance_scale <= 1.0: + raise ValueError("has to use guidance_scale") + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device) + sigmas = self.scheduler.sigmas + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + latents = latents * sigmas[0] + self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device) + self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device) + + def model_fn(x, t): + latent_model_input = torch.cat([x] * 2) + + noise_pred = self.k_diffusion_model(latent_model_input, t, encoder_hidden_states=text_embeddings) + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + return noise_pred + + latents = self.sampler(model_fn, latents, sigmas) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/speech_to_image_diffusion.py b/examples/community/speech_to_image_diffusion.py index 1a9d296e81..17bc08e3c2 100644 --- a/examples/community/speech_to_image_diffusion.py +++ b/examples/community/speech_to_image_diffusion.py @@ -42,7 +42,7 @@ class SpeechToImagePipeline(DiffusionPipeline): super().__init__() if safety_checker is None: - logger.warn( + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" diff --git a/examples/community/text_inpainting.py b/examples/community/text_inpainting.py new file mode 100644 index 0000000000..a4368f8b43 --- /dev/null +++ b/examples/community/text_inpainting.py @@ -0,0 +1,320 @@ +from typing import Callable, List, Optional, Union + +import torch + +import PIL +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import deprecate, is_accelerate_available, logging +from transformers import ( + CLIPFeatureExtractor, + CLIPSegForImageSegmentation, + CLIPSegProcessor, + CLIPTextModel, + CLIPTokenizer, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class TextInpainting(DiffusionPipeline): + r""" + Pipeline for text based inpainting using Stable Diffusion. + Uses CLIPSeg to get a mask from the given text, then calls the Inpainting pipeline with the generated mask + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + segmentation_model ([`CLIPSegForImageSegmentation`]): + CLIPSeg Model to generate mask from the given text. Please refer to the [model card]() for details. + segmentation_processor ([`CLIPSegProcessor`]): + CLIPSeg processor to get image, text features to translate prompt to English, if necessary. Please refer to the + [model card](https://huggingface.co/docs/transformers/model_doc/clipseg) for details. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + segmentation_model: CLIPSegForImageSegmentation, + segmentation_processor: CLIPSegProcessor, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + segmentation_model=segmentation_model, + segmentation_processor=segmentation_processor, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device("cuda") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image], + text: str, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + text (`str``): + The text to use to generate the mask. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # We use the input text to generate the mask + inputs = self.segmentation_processor( + text=[text], images=[image], padding="max_length", return_tensors="pt" + ).to(self.device) + outputs = self.segmentation_model(**inputs) + mask = torch.sigmoid(outputs.logits).cpu().detach().unsqueeze(-1).numpy() + mask_pil = self.numpy_to_pil(mask)[0].resize(image.size) + + # Run inpainting pipeline with the generated mask + inpainting_pipeline = StableDiffusionInpaintPipeline( + vae=self.vae, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + unet=self.unet, + scheduler=self.scheduler, + safety_checker=self.safety_checker, + feature_extractor=self.feature_extractor, + ) + return inpainting_pipeline( + prompt=prompt, + image=image, + mask_image=mask_pil, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) diff --git a/examples/community/wildcard_stable_diffusion.py b/examples/community/wildcard_stable_diffusion.py index 9ad0d8e9fa..282be8e48b 100644 --- a/examples/community/wildcard_stable_diffusion.py +++ b/examples/community/wildcard_stable_diffusion.py @@ -135,7 +135,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline): scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None: - logger.warn( + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 3c9d04abc2..e202126fbb 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -39,6 +39,8 @@ Now let's get our dataset. Download images from [here](https://drive.google.com/ And launch the training using +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" export INSTANCE_DIR="path-to-instance-images" @@ -92,7 +94,7 @@ accelerate launch train_dreambooth.py \ With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run train dreambooth on a 16GB GPU. -Install `bitsandbytes` with `pip install bitsandbytes` +To install `bitandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation). ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" @@ -141,7 +143,7 @@ export INSTANCE_DIR="path-to-instance-images" export CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" -accelerate launch train_dreambooth.py \ +accelerate launch --mixed_precision="fp16" train_dreambooth.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --instance_data_dir=$INSTANCE_DIR \ --class_data_dir=$CLASS_DIR \ @@ -157,8 +159,7 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 \ - --mixed_precision=fp16 + --max_train_steps=800 ``` ### Fine-tune text encoder with the UNet. diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 610c18533b..331e3ae922 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -124,6 +124,7 @@ def parse_args(input_args=None): default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) + parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") parser.add_argument( "--gradient_accumulation_steps", type=int, @@ -187,12 +188,12 @@ def parse_args(input_args=None): parser.add_argument( "--mixed_precision", type=str, - default="no", + default=None, choices=["no", "fp16", "bf16"], help=( - "Whether to use mixed precision. Choose" - "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU." + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") @@ -472,7 +473,7 @@ def main(args): eps=args.adam_epsilon, ) - noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") + noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -538,9 +539,9 @@ def main(args): ) weight_dtype = torch.float32 - if args.mixed_precision == "fp16": + if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": + elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move text_encode and vae to gpu. @@ -603,23 +604,31 @@ def main(args): encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") if args.with_prior_preservation: - # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. - noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) - noise, noise_prior = torch.chunk(noise, 2, dim=0) + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) # Compute instance loss - loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean() + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean() # Compute prior loss - prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean") + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss else: - loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: @@ -638,6 +647,17 @@ def main(args): progress_bar.update(1) global_step += 1 + if global_step % args.save_steps == 0: + if accelerator.is_main_process: + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + revision=args.revision, + ) + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + pipeline.save_pretrained(save_path) + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 170ed384f1..cfe82e8f90 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -42,11 +42,13 @@ If you have already cloned the repo, then you won't need to go through these ste #### Hardware With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory. +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" export dataset_name="lambdalabs/pokemon-blip-captions" -accelerate launch train_text_to_image.py \ +accelerate launch --mixed_precision="fp16" train_text_to_image.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --dataset_name=$dataset_name \ --use_ema \ @@ -54,7 +56,6 @@ accelerate launch train_text_to_image.py \ --train_batch_size=1 \ --gradient_accumulation_steps=4 \ --gradient_checkpointing \ - --mixed_precision="fp16" \ --max_train_steps=15000 \ --learning_rate=1e-05 \ --max_grad_norm=1 \ @@ -70,7 +71,7 @@ If you wish to use custom loading logic, you should modify the script, we have l export MODEL_NAME="CompVis/stable-diffusion-v1-4" export TRAIN_DIR="path_to_your_dataset" -accelerate launch train_text_to_image.py \ +accelerate launch --mixed_precision="fp16" train_text_to_image.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --train_data_dir=$TRAIN_DIR \ --use_ema \ @@ -78,7 +79,6 @@ accelerate launch train_text_to_image.py \ --train_batch_size=1 \ --gradient_accumulation_steps=4 \ --gradient_checkpointing \ - --mixed_precision="fp16" \ --max_train_steps=15000 \ --learning_rate=1e-05 \ --max_grad_norm=1 \ diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index cf7dac8933..1027b7a8ba 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -15,13 +15,12 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed from datasets import load_dataset -from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler -from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from huggingface_hub import HfFolder, Repository, whoami from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer logger = get_logger(__name__) @@ -36,6 +35,13 @@ def parse_args(): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) parser.add_argument( "--dataset_name", type=str, @@ -186,12 +192,12 @@ def parse_args(): parser.add_argument( "--mixed_precision", type=str, - default="no", + default=None, choices=["no", "fp16", "bf16"], help=( - "Whether to use mixed precision. Choose" - "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU." + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( @@ -335,10 +341,24 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) # Load models and create wrapper for stable diffusion - tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") - text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + ) # Freeze vae and text_encoder vae.requires_grad_(False) @@ -372,7 +392,7 @@ def main(): weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) - noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") + noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). @@ -496,9 +516,9 @@ def main(): ) weight_dtype = torch.float32 - if args.mixed_precision == "fp16": + if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": + elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move text_encode and vae to gpu. @@ -562,9 +582,17 @@ def main(): # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + # Predict the noise residual and compute loss - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Gather the losses across all processes for logging (if we use distributed training). avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() @@ -600,14 +628,12 @@ def main(): if args.use_ema: ema_unet.copy_to(unet.parameters()) - pipeline = StableDiffusionPipeline( + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, - tokenizer=tokenizer, - scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"), - safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + revision=args.revision, ) pipeline.save_pretrained(args.output_dir) diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md index 2edf34cb49..3aeb6e50c7 100644 --- a/examples/textual_inversion/README.md +++ b/examples/textual_inversion/README.md @@ -47,6 +47,8 @@ Now let's get our dataset.Download 3-4 images from [here](https://drive.google.c And launch the training using +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + ```bash export MODEL_NAME="runwayml/stable-diffusion-v1-5" export DATA_DIR="path-to-dir-containing-images" diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index fc9380edcd..77ef350c51 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -16,24 +16,45 @@ import PIL from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed -from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler -from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from huggingface_hub import HfFolder, Repository, whoami + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version from PIL import Image from torchvision import transforms from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer + + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ logger = get_logger(__name__) -def save_progress(text_encoder, placeholder_token_id, accelerator, args): +def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path): logger.info("Saving embeddings") learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} - torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin")) + torch.save(learned_embeds_dict, save_path) def parse_args(): @@ -51,6 +72,13 @@ def parse_args(): required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) parser.add_argument( "--tokenizer_name", type=str, @@ -260,10 +288,10 @@ class TextualInversionDataset(Dataset): self._length = self.num_images * repeats self.interpolation = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, + "linear": PIL_INTERPOLATION["linear"], + "bilinear": PIL_INTERPOLATION["bilinear"], + "bicubic": PIL_INTERPOLATION["bicubic"], + "lanczos": PIL_INTERPOLATION["lanczos"], }[interpolation] self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small @@ -383,9 +411,21 @@ def main(): placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) # Load models and create wrapper for stable diffusion - text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") - unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + ) # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) @@ -419,7 +459,7 @@ def main(): eps=args.adam_epsilon, ) - noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") + noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") train_dataset = TextualInversionDataset( data_root=args.train_data_dir, @@ -510,9 +550,17 @@ def main(): encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) # Zero out the gradients for all token embeddings except the newly added @@ -534,7 +582,8 @@ def main(): progress_bar.update(1) global_step += 1 if global_step % args.save_steps == 0: - save_progress(text_encoder, placeholder_token_id, accelerator, args) + save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin") + save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -547,18 +596,18 @@ def main(): # Create the pipeline using using the trained modules and save it. if accelerator.is_main_process: - pipeline = StableDiffusionPipeline( + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, vae=vae, unet=unet, - tokenizer=tokenizer, - scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"), - safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), - feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + revision=args.revision, ) pipeline.save_pretrained(args.output_dir) # Also save the newly trained embeddings - save_progress(text_encoder, placeholder_token_id, accelerator, args) + save_path = os.path.join(args.output_dir, "learned_embeds.bin") + save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index be2b7ffb54..6406be8ad6 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -28,12 +28,33 @@ from flax import jax_utils from flax.training import train_state from flax.training.common_utils import shard from huggingface_hub import HfFolder, Repository, whoami + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version from PIL import Image from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + logger = logging.getLogger(__name__) @@ -246,10 +267,10 @@ class TextualInversionDataset(Dataset): self._length = self.num_images * repeats self.interpolation = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, + "linear": PIL_INTERPOLATION["linear"], + "bilinear": PIL_INTERPOLATION["bilinear"], + "bicubic": PIL_INTERPOLATION["bicubic"], + "lanczos": PIL_INTERPOLATION["lanczos"], }[interpolation] self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index e9c461b482..dbb8491789 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -127,3 +127,24 @@ dataset.push_to_hub("name_of_your_dataset", private=True) and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub. More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets). + +#### Use ONNXRuntime to accelerate training + +In order to leverage onnxruntime to accelerate training, please use train_unconditional_ort.py + +The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxruntime: + +```bash +accelerate launch train_unconditional_ort.py \ + --dataset_name="huggan/flowers-102-categories" \ + --resolution=64 \ + --output_dir="ddpm-ema-flowers-64" \ + --train_batch_size=16 \ + --num_epochs=1 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-4 \ + --lr_warmup_steps=500 \ + --mixed_precision=fp16 + ``` + +Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions. \ No newline at end of file diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 54a94d98b5..fc5be82b6a 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -194,9 +194,10 @@ def parse_args(): ) parser.add_argument( - "--predict_epsilon", - action="store_true", - default=True, + "--prediction_type", + type=str, + default="epsilon", + choices=["epsilon", "sample"], help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", ) @@ -256,13 +257,13 @@ def main(args): "UpBlock2D", ), ) - accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) + accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) - if accepts_predict_epsilon: + if accepts_prediction_type: noise_scheduler = DDPMScheduler( num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule, - predict_epsilon=args.predict_epsilon, + prediction_type=args.prediction_type, ) else: noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule) @@ -319,7 +320,12 @@ def main(args): num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay) + ema_model = EMAModel( + accelerator.unwrap_model(model), + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + max_value=args.ema_max_decay, + ) # Handle the repository creation if accelerator.is_main_process: @@ -365,9 +371,9 @@ def main(args): # Predict the noise residual model_output = model(noisy_images, timesteps).sample - if args.predict_epsilon: + if args.prediction_type == "epsilon": loss = F.mse_loss(model_output, noise) # this could have different weights! - else: + elif args.prediction_type == "sample": alpha_t = _extract_into_tensor( noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1) ) @@ -376,6 +382,8 @@ def main(args): model_output, clean_images, reduction="none" ) # use SNR weighting from distillation paper loss = loss.mean() + else: + raise ValueError(f"Unsupported prediction type: {args.prediction_type}") accelerator.backward(loss) diff --git a/examples/unconditional_image_generation/train_unconditional_ort.py b/examples/unconditional_image_generation/train_unconditional_ort.py new file mode 100644 index 0000000000..8259c835fc --- /dev/null +++ b/examples/unconditional_image_generation/train_unconditional_ort.py @@ -0,0 +1,251 @@ +import argparse +import math +import os + +import torch +import torch.nn.functional as F + +from accelerate import Accelerator +from accelerate.logging import get_logger +from datasets import load_dataset +from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers.hub_utils import init_git_repo, push_to_hub +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from onnxruntime.training.ortmodule import ORTModule +from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + RandomHorizontalFlip, + Resize, + ToTensor, +) +from tqdm.auto import tqdm + + +logger = get_logger(__name__) + + +def main(args): + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + model = UNet2DModel( + sample_size=args.resolution, + in_channels=3, + out_channels=3, + layers_per_block=2, + block_out_channels=(128, 128, 256, 256, 512, 512), + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + model = ORTModule(model) + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt") + optimizer = torch.optim.AdamW( + model.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + augmentations = Compose( + [ + Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), + CenterCrop(args.resolution), + RandomHorizontalFlip(), + ToTensor(), + Normalize([0.5], [0.5]), + ] + ) + + if args.dataset_name is not None: + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + use_auth_token=True if args.use_auth_token else None, + split="train", + ) + else: + dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train") + + def transforms(examples): + images = [augmentations(image.convert("RGB")) for image in examples["image"]] + return {"input": images} + + dataset.set_transform(transforms) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + + ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay) + + if args.push_to_hub: + repo = init_git_repo(args, at_init=True) + + if accelerator.is_main_process: + run = os.path.split(__file__)[-1].split(".")[0] + accelerator.init_trackers(run) + + global_step = 0 + for epoch in range(args.num_epochs): + model.train() + progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Epoch {epoch}") + for step, batch in enumerate(train_dataloader): + clean_images = batch["input"] + # Sample noise that we'll add to the images + noise = torch.randn(clean_images.shape).to(clean_images.device) + bsz = clean_images.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device + ).long() + + # Add noise to the clean images according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) + + with accelerator.accumulate(model): + # Predict the noise residual + noise_pred = model(noisy_images, timesteps, return_dict=True)[0] + loss = F.mse_loss(noise_pred, noise) + accelerator.backward(loss) + + accelerator.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + if args.use_ema: + ema_model.step(model) + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} + if args.use_ema: + logs["ema_decay"] = ema_model.decay + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + progress_bar.close() + + accelerator.wait_for_everyone() + + # Generate sample images for visual inspection + if accelerator.is_main_process: + if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: + pipeline = DDPMPipeline( + unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model), + scheduler=noise_scheduler, + ) + + generator = torch.manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images + + # denormalize the images and save to tensorboard + images_processed = (images * 255).round().astype("uint8") + accelerator.trackers[0].writer.add_images( + "test_samples", images_processed.transpose(0, 3, 1, 2), epoch + ) + + if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: + # save the model + if args.push_to_hub: + push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) + else: + pipeline.save_pretrained(args.output_dir) + accelerator.wait_for_everyone() + + accelerator.end_training() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument("--local_rank", type=int, default=-1) + parser.add_argument("--dataset_name", type=str, default=None) + parser.add_argument("--dataset_config_name", type=str, default=None) + parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.") + parser.add_argument("--output_dir", type=str, default="ddpm-model-64") + parser.add_argument("--overwrite_output_dir", action="store_true") + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--resolution", type=int, default=64) + parser.add_argument("--train_batch_size", type=int, default=16) + parser.add_argument("--eval_batch_size", type=int, default=16) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--save_images_epochs", type=int, default=10) + parser.add_argument("--save_model_epochs", type=int, default=10) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--lr_scheduler", type=str, default="cosine") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + parser.add_argument("--adam_beta1", type=float, default=0.95) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-6) + parser.add_argument("--adam_epsilon", type=float, default=1e-08) + parser.add_argument("--use_ema", action="store_true", default=True) + parser.add_argument("--ema_inv_gamma", type=float, default=1.0) + parser.add_argument("--ema_power", type=float, default=3 / 4) + parser.add_argument("--ema_max_decay", type=float, default=0.9999) + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--use_auth_token", action="store_true") + parser.add_argument("--hub_token", type=str, default=None) + parser.add_argument("--hub_model_id", type=str, default=None) + parser.add_argument("--hub_private_repo", action="store_true") + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("You must specify either a dataset name from the hub or a train data directory.") + + main(args) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 375b12b6f8..2d354df938 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -211,6 +211,7 @@ def create_unet_diffusers_config(original_config): """ Creates a config for the diffusers based on the config of the LDM model. """ + model_params = original_config.model.params unet_params = original_config.model.params.unet_config.params block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] @@ -230,7 +231,7 @@ def create_unet_diffusers_config(original_config): resolution //= 2 config = dict( - sample_size=unet_params.image_size, + sample_size=model_params.image_size, in_channels=unet_params.in_channels, out_channels=unet_params.out_channels, down_block_types=tuple(down_block_types), diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py index f0e0b178af..26d3d5618f 100644 --- a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py +++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py @@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F ) del pipeline.safety_checker safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker") + feature_extractor = pipeline.feature_extractor else: safety_checker = None + feature_extractor = None onnx_pipeline = OnnxStableDiffusionPipeline( vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"), @@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"), scheduler=pipeline.scheduler, safety_checker=safety_checker, - feature_extractor=pipeline.feature_extractor, + feature_extractor=feature_extractor, + requires_safety_checker=safety_checker is not None, ) onnx_pipeline.save_pretrained(output_path) diff --git a/scripts/convert_versatile_diffusion_to_diffusers.py b/scripts/convert_versatile_diffusion_to_diffusers.py new file mode 100644 index 0000000000..86fb0e7b4c --- /dev/null +++ b/scripts/convert_versatile_diffusion_to_diffusers.py @@ -0,0 +1,791 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the Versatile Stable Diffusion checkpoints. """ + +import argparse +from argparse import Namespace + +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, + VersatileDiffusionPipeline, +) +from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel +from transformers import ( + CLIPFeatureExtractor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + + +SCHEDULER_CONFIG = Namespace( + **{ + "beta_linear_start": 0.00085, + "beta_linear_end": 0.012, + "timesteps": 1000, + "scale_factor": 0.18215, + } +) + +IMAGE_UNET_CONFIG = Namespace( + **{ + "input_channels": 4, + "model_channels": 320, + "output_channels": 4, + "num_noattn_blocks": [2, 2, 2, 2], + "channel_mult": [1, 2, 4, 4], + "with_attn": [True, True, True, False], + "num_heads": 8, + "context_dim": 768, + "use_checkpoint": True, + } +) + +TEXT_UNET_CONFIG = Namespace( + **{ + "input_channels": 768, + "model_channels": 320, + "output_channels": 768, + "num_noattn_blocks": [2, 2, 2, 2], + "channel_mult": [1, 2, 4, 4], + "second_dim": [4, 4, 4, 4], + "with_attn": [True, True, True, False], + "num_heads": 8, + "context_dim": 768, + "use_checkpoint": True, + } +) + +AUTOENCODER_CONFIG = Namespace( + **{ + "double_z": True, + "z_channels": 4, + "resolution": 256, + "in_channels": 3, + "out_ch": 3, + "ch": 128, + "ch_mult": [1, 2, 4, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + } +) + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif path["old"] in old_checkpoint: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_image_unet_diffusers_config(unet_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if unet_params.with_attn[i] else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if unet_params.with_attn[-i - 1] else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): + raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") + + config = dict( + sample_size=None, + in_channels=unet_params.input_channels, + out_channels=unet_params.output_channels, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_noattn_blocks[0], + cross_attention_dim=unet_params.context_dim, + attention_head_dim=unet_params.num_heads, + ) + + return config + + +def create_text_unet_diffusers_config(unet_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlockFlat" if unet_params.with_attn[i] else "DownBlockFlat" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlockFlat" if unet_params.with_attn[-i - 1] else "UpBlockFlat" + up_block_types.append(block_type) + resolution //= 2 + + if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks): + raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.") + + config = dict( + sample_size=None, + in_channels=(unet_params.input_channels, 1, 1), + out_channels=(unet_params.output_channels, 1, 1), + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_noattn_blocks[0], + cross_attention_dim=unet_params.context_dim, + attention_head_dim=unet_params.num_heads, + ) + + return config + + +def create_vae_diffusers_config(vae_params): + """ + Creates a config for the diffusers based on the config of the VD model. + """ + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=vae_params.resolution, + in_channels=vae_params.in_channels, + out_channels=vae_params.out_ch, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=vae_params.z_channels, + layers_per_block=vae_params.num_res_blocks, + ) + return config + + +def create_diffusers_scheduler(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100: + print("Checkpoint has both EMA and non-EMA weights.") + if extract_ema: + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["model.diffusion_model.time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["model.diffusion_model.time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["model.diffusion_model.time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["model.diffusion_model.time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + elif f"input_blocks.{i}.0.weight" in unet_state_dict: + # text_unet uses linear layers in place of downsamplers + shape = unet_state_dict[f"input_blocks.{i}.0.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if ["conv.weight", "conv.bias"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + elif f"output_blocks.{i}.1.weight" in unet_state_dict: + # text_unet uses linear layers in place of upsamplers + shape = unet_state_dict[f"output_blocks.{i}.1.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop( + f"output_blocks.{i}.1.weight" + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop( + f"output_blocks.{i}.1.bias" + ) + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + elif f"output_blocks.{i}.2.weight" in unet_state_dict: + # text_unet uses linear layers in place of upsamplers + shape = unet_state_dict[f"output_blocks.{i}.2.weight"].shape + if shape[0] != shape[1]: + continue + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop( + f"output_blocks.{i}.2.weight" + ) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop( + f"output_blocks.{i}.2.bias" + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +def convert_vd_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + for key in keys: + vae_state_dict[key] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--unet_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--vae_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--optimus_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." + ) + parser.add_argument( + "--scheduler_type", + default="pndm", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']", + ) + parser.add_argument( + "--extract_ema", + action="store_true", + help=( + "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" + " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" + " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." + ), + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + args = parser.parse_args() + + scheduler_config = SCHEDULER_CONFIG + + num_train_timesteps = scheduler_config.timesteps + beta_start = scheduler_config.beta_linear_start + beta_end = scheduler_config.beta_linear_end + if args.scheduler_type == "pndm": + scheduler = PNDMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + skip_prk_steps=True, + steps_offset=1, + ) + elif args.scheduler_type == "lms": + scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "euler": + scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear" + ) + elif args.scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear" + ) + elif args.scheduler_type == "ddim": + scheduler = DDIMScheduler( + beta_start=beta_start, + beta_end=beta_end, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + else: + raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel models. + if args.unet_checkpoint_path is not None: + # image UNet + image_unet_config = create_image_unet_diffusers_config(IMAGE_UNET_CONFIG) + checkpoint = torch.load(args.unet_checkpoint_path) + converted_image_unet_checkpoint = convert_vd_unet_checkpoint( + checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema + ) + image_unet = UNet2DConditionModel(**image_unet_config) + image_unet.load_state_dict(converted_image_unet_checkpoint) + + # text UNet + text_unet_config = create_text_unet_diffusers_config(TEXT_UNET_CONFIG) + converted_text_unet_checkpoint = convert_vd_unet_checkpoint( + checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema + ) + text_unet = UNetFlatConditionModel(**text_unet_config) + text_unet.load_state_dict(converted_text_unet_checkpoint) + + # Convert the VAE model. + if args.vae_checkpoint_path is not None: + vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG) + checkpoint = torch.load(args.vae_checkpoint_path) + converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + image_feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14") + text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + + pipe = VersatileDiffusionPipeline( + scheduler=scheduler, + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + ) + pipe.save_pretrained(args.dump_path) diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py index ae105e3036..85db67844a 100644 --- a/scripts/convert_vq_diffusion_to_diffusers.py +++ b/scripts/convert_vq_diffusion_to_diffusers.py @@ -39,8 +39,8 @@ import torch import yaml from accelerate import init_empty_weights, load_checkpoint_and_dispatch -from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel -from diffusers.models.attention import Transformer2DModel +from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel +from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings from transformers import CLIPTextModel, CLIPTokenizer from yaml.loader import FullLoader @@ -826,6 +826,20 @@ if __name__ == "__main__": transformer_model, checkpoint ) + # classifier free sampling embeddings interlude + + # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate + # model, so we pull them off the checkpoint before the checkpoint is deleted. + + learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf + + if learnable_classifier_free_sampling_embeddings: + learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"] + else: + learned_classifier_free_sampling_embeddings_embeddings = None + + # done classifier free sampling embeddings interlude + with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file: torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name) del diffusers_transformer_checkpoint @@ -871,6 +885,31 @@ if __name__ == "__main__": # done scheduler + # learned classifier free sampling embeddings + + with init_empty_weights(): + learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings( + learnable_classifier_free_sampling_embeddings, + hidden_size=text_encoder_model.config.hidden_size, + length=tokenizer_model.model_max_length, + ) + + learned_classifier_free_sampling_checkpoint = { + "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float() + } + + with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file: + torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name) + del learned_classifier_free_sampling_checkpoint + del learned_classifier_free_sampling_embeddings_embeddings + load_checkpoint_and_dispatch( + learned_classifier_free_sampling_embeddings_model, + learned_classifier_free_sampling_checkpoint_file.name, + device_map="auto", + ) + + # done learned classifier free sampling embeddings + print(f"saving VQ diffusion model, path: {args.dump_path}") pipe = VQDiffusionPipeline( @@ -878,6 +917,7 @@ if __name__ == "__main__": transformer=transformer_model, tokenizer=tokenizer_model, text_encoder=text_encoder_model, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model, scheduler=scheduler_model, ) pipe.save_pretrained(args.dump_path) diff --git a/setup.py b/setup.py index 1bb6af4b10..4ebec86927 100644 --- a/setup.py +++ b/setup.py @@ -78,7 +78,7 @@ from setuptools import find_packages, setup # 1. all dependencies should be listed here with their version requirements if any # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py _deps = [ - "Pillow<10.0", # keep the PIL.Image.Resampling deprecation away + "Pillow", # keep the PIL.Image.Resampling deprecation away "accelerate>=0.11.0", "black==22.8", "datasets", @@ -97,6 +97,8 @@ _deps = [ "pytest", "pytest-timeout", "pytest-xdist", + "safetensors", + "sentencepiece>=0.1.91,!=0.1.92", "scipy", "regex!=2019.12.17", "requests", @@ -183,9 +185,11 @@ extras["test"] = deps_list( "pytest", "pytest-timeout", "pytest-xdist", + "safetensors", + "sentencepiece", "scipy", "torchvision", - "transformers" + "transformers", ) extras["torch"] = deps_list("torch", "accelerate") @@ -210,7 +214,7 @@ install_requires = [ setup( name="diffusers", - version="0.8.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.9.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="Diffusers", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 86eda7371f..93f2f3a13a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -9,7 +9,7 @@ from .utils import ( ) -__version__ = "0.8.0.dev0" +__version__ = "0.9.0" from .configuration_utils import ConfigMixin from .onnx_utils import OnnxRuntimeModel @@ -46,6 +46,7 @@ if is_torch_available(): DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + HeunDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, PNDMScheduler, @@ -65,12 +66,21 @@ else: if is_torch_available() and is_transformers_available(): from .pipelines import ( + AltDiffusionImg2ImgPipeline, + AltDiffusionPipeline, CycleDiffusionPipeline, LDMTextToImagePipeline, + StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, + StableDiffusionPipelineSafe, + StableDiffusionUpscalePipeline, + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, VQDiffusionPipeline, ) else: @@ -80,6 +90,7 @@ if is_torch_available() and is_transformers_available() and is_onnx_available(): from .pipelines import ( OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline, ) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index fc6ac9b5b9..f06586b236 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -29,7 +29,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R from requests import HTTPError from . import __version__ -from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging +from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging logger = logging.get_logger(__name__) @@ -37,6 +37,38 @@ logger = logging.get_logger(__name__) _re_configuration_file = re.compile(r"config\.(.*)\.json") +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) + + class ConfigMixin: r""" Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all @@ -48,21 +80,21 @@ class ConfigMixin: - **config_name** (`str`) -- A filename under which the config should stored when calling [`~ConfigMixin.save_config`] (should be overridden by parent class). - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be - overridden by parent class). - - **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that - `from_config` can be used from a class different than the one used to save the config (should be overridden - by parent class). + overridden by subclass). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass). + - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function + should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by + subclass). """ config_name = None ignore_for_config = [] - _compatible_classes = [] + has_compatibles = False + + _deprecated_kwargs = [] def register_to_config(self, **kwargs): if self.config_name is None: raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") - kwargs["_class_name"] = self.__class__.__name__ - kwargs["_diffusers_version"] = __version__ - # Special case for `kwargs` used in deprecation warning added to schedulers # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument, # or solve in a more general way. @@ -104,9 +136,103 @@ class ConfigMixin: logger.info(f"Configuration saved in {output_config_file}") @classmethod - def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): r""" - Instantiate a Python class from a pre-defined JSON-file. + Instantiate a Python class from a config dictionary + + Parameters: + config (`Dict[str, Any]`): + A config dictionary from which the Python class will be instantiated. Make sure to only load + configuration files of compatible classes. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the Python class. + `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually + overwrite same named arguments of `config`. + + Examples: + + ```python + >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler + + >>> # Download scheduler from huggingface.co and cache. + >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") + + >>> # Instantiate DDIM scheduler class with same config as DDPM + >>> scheduler = DDIMScheduler.from_config(scheduler.config) + + >>> # Instantiate PNDM scheduler class with same config as DDPM + >>> scheduler = PNDMScheduler.from_config(scheduler.config) + ``` + """ + # <===== TO BE REMOVED WITH DEPRECATION + # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + + if config is None: + raise ValueError("Please make sure to provide a config as the first positional argument.") + # ======> + + if not isinstance(config, dict): + deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." + if "Scheduler" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead." + " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will" + " be removed in v1.0.0." + ) + elif "Model" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a model, please use {cls}.load_config(...) followed by" + f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary" + " instead. This functionality will be removed in v1.0.0." + ) + deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False) + config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs) + + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) + + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + + # add possible deprecated kwargs + for deprecated_kwarg in cls._deprecated_kwargs: + if deprecated_kwarg in unused_kwargs: + init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg) + + # Return model and optionally state and/or unused_kwargs + model = cls(**init_dict) + + # make sure to also save config parameters that might be used for compatible classes + model.register_to_config(**hidden_dict) + + # add hidden kwargs of compatible classes to unused_kwargs + unused_kwargs = {**unused_kwargs, **hidden_dict} + + if return_unused_kwargs: + return (model, unused_kwargs) + else: + return model + + @classmethod + def get_config_dict(cls, *args, **kwargs): + deprecation_message = ( + f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be" + " removed in version v1.0.0" + ) + deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) + return cls.load_config(*args, **kwargs) + + @classmethod + def load_config( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + r""" + Instantiate a Python class from a config dictionary Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): @@ -120,10 +246,6 @@ class ConfigMixin: cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): - Whether or not to raise an error if some of the weights from the checkpoint do not have the same size - as the weights of the model (if for instance, you are instantiating a model with 10 labels from a - checkpoint with 3 labels). force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -161,33 +283,7 @@ class ConfigMixin: use this method in a firewalled environment. - """ - config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) - init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) - - # Allow dtype to be specified on initialization - if "dtype" in unused_kwargs: - init_dict["dtype"] = unused_kwargs.pop("dtype") - - # Return model and optionally state and/or unused_kwargs - model = cls(**init_dict) - return_tuple = (model,) - - # Flax schedulers have a state, so return it. - if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False): - state = model.create_state() - return_tuple += (state,) - - if return_unused_kwargs: - return return_tuple + (unused_kwargs,) - else: - return return_tuple if len(return_tuple) > 1 else model - - @classmethod - def get_config_dict( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -283,6 +379,9 @@ class ConfigMixin: except (json.JSONDecodeError, UnicodeDecodeError): raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + if return_unused_kwargs: + return config_dict, kwargs + return config_dict @staticmethod @@ -291,6 +390,9 @@ class ConfigMixin: @classmethod def extract_init_dict(cls, config_dict, **kwargs): + # 0. Copy origin config dict + original_dict = {k: v for k, v in config_dict.items()} + # 1. Retrieve expected config attributes from __init__ signature expected_keys = cls._get_init_keys(cls) expected_keys.remove("self") @@ -310,10 +412,11 @@ class ConfigMixin: # load diffusers library to import compatible and original scheduler diffusers_library = importlib.import_module(__name__.split(".")[0]) - # remove attributes from compatible classes that orig cannot expect - compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes] - # filter out None potentially undefined dummy classes - compatible_classes = [c for c in compatible_classes if c is not None] + if cls.has_compatibles: + compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)] + else: + compatible_classes = [] + expected_keys_comp_cls = set() for c in compatible_classes: expected_keys_c = cls._get_init_keys(c) @@ -364,7 +467,10 @@ class ConfigMixin: # 6. Define unused keyword arguments unused_kwargs = {**config_dict, **kwargs} - return init_dict, unused_kwargs + # 7. Define "hidden" config parameters that were saved for compatible classes + hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict} + + return init_dict, unused_kwargs, hidden_config_dict @classmethod def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): @@ -377,6 +483,12 @@ class ConfigMixin: @property def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ return self._internal_dict def to_json_string(self) -> str: @@ -387,6 +499,9 @@ class ConfigMixin: `str`: String containing all the attributes that make up this configuration instance in JSON format. """ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {} + config_dict["_class_name"] = self.__class__.__name__ + config_dict["_diffusers_version"] = __version__ + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def to_json_file(self, json_file_path: Union[str, os.PathLike]): @@ -401,38 +516,6 @@ class ConfigMixin: writer.write(self.to_json_string()) -class FrozenDict(OrderedDict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - for key, value in self.items(): - setattr(self, key, value) - - self.__frozen = True - - def __delitem__(self, *args, **kwargs): - raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") - - def setdefault(self, *args, **kwargs): - raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") - - def pop(self, *args, **kwargs): - raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") - - def update(self, *args, **kwargs): - raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") - - def __setattr__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") - super().__setattr__(name, value) - - def __setitem__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") - super().__setitem__(name, value) - - def register_to_config(init): r""" Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are @@ -446,7 +529,7 @@ def register_to_config(init): def inner_init(self, *args, **kwargs): # Ignore private kwargs in the init. init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} - init(self, *args, **init_kwargs) + config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")} if not isinstance(self, ConfigMixin): raise RuntimeError( f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " @@ -471,7 +554,9 @@ def register_to_config(init): if k not in ignore and k not in new_kwargs } ) + new_kwargs = {**config_init_kwargs, **new_kwargs} getattr(self, "register_to_config")(**new_kwargs) + init(self, *args, **init_kwargs) return inner_init @@ -488,7 +573,7 @@ def flax_register_to_config(cls): ) # Ignore private kwargs in the init. Retrieve all passed attributes - init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + init_kwargs = {k: v for k, v in kwargs.items()} # Retrieve default values fields = dataclasses.fields(self) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 59e13da0f2..2fd6bfa1fa 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -2,7 +2,7 @@ # 1. modify the `_deps` dict in setup.py # 2. run `make deps_table_update`` deps = { - "Pillow": "Pillow<10.0", + "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", "black": "black==22.8", "datasets": "datasets", @@ -21,6 +21,8 @@ deps = { "pytest": "pytest", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", + "safetensors": "safetensors", + "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", "scipy": "scipy", "regex": "regex!=2019.12.17", "requests": "requests", diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py index 8d5062e3d4..4dd935f54d 100644 --- a/src/diffusers/experimental/rl/value_guided_sampling.py +++ b/src/diffusers/experimental/rl/value_guided_sampling.py @@ -89,6 +89,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline): x = x + scale * grad x = self.reset_x0(x, conditions, self.action_dim) prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + # TODO: set prediction_type when instantiating the model x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] # apply conditions to the trajectory diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 5ef1002249..857fdd1b0b 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -332,7 +332,7 @@ class FlaxModelMixin: elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): raise EnvironmentError( f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model" - " using `from_pt=True`." + " using `from_pt=True`." ) else: raise EnvironmentError( diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 1e91ccd56a..5f79e7fe01 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -30,8 +30,10 @@ from .utils import ( CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, + SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, is_accelerate_available, + is_safetensors_available, is_torch_version, logging, ) @@ -51,6 +53,9 @@ if is_accelerate_available(): from accelerate.utils import set_module_tensor_to_device from accelerate.utils.versions import is_torch_version +if is_safetensors_available(): + import safetensors + def get_parameter_device(parameter: torch.nn.Module): try: @@ -84,10 +89,13 @@ def get_parameter_dtype(parameter: torch.nn.Module): def load_state_dict(checkpoint_file: Union[str, os.PathLike]): """ - Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. + Reads a checkpoint file, returning properly formatted errors if they arise. """ try: - return torch.load(checkpoint_file, map_location="cpu") + if os.path.basename(checkpoint_file) == WEIGHTS_NAME: + return torch.load(checkpoint_file, map_location="cpu") + else: + return safetensors.torch.load_file(checkpoint_file, device="cpu") except Exception as e: try: with open(checkpoint_file) as f: @@ -104,7 +112,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): ) from e except (UnicodeDecodeError, ValueError): raise OSError( - f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. " "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." ) @@ -332,7 +340,7 @@ class ModelMixin(torch.nn.Module): if low_cpu_mem_usage and not is_accelerate_available(): low_cpu_mem_usage = False - logger.warn( + logger.warning( "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" @@ -375,80 +383,44 @@ class ModelMixin(torch.nn.Module): # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the # Load model - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): - # Load from a PyTorch checkpoint - model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) - elif subfolder is not None and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) - ): - model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME) - else: - raise EnvironmentError( - f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}." - ) - else: + + model_file = None + if is_safetensors_available(): try: - # Load from URL or cache if already cached - model_file = hf_hub_download( + model_file = _get_model_file( pretrained_model_name_or_path, - filename=WEIGHTS_NAME, + weights_name=SAFETENSORS_WEIGHTS_NAME, cache_dir=cache_dir, force_download=force_download, - proxies=proxies, resume_download=resume_download, + proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, - user_agent=user_agent, - subfolder=subfolder, revision=revision, + subfolder=subfolder, + user_agent=user_agent, ) - - except RepositoryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " - "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " - "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " - "login`." - ) - except RevisionNotFoundError: - raise EnvironmentError( - f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " - "this model name. Check the model page at " - f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." - ) - except EntryNotFoundError: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}." - ) - except HTTPError as err: - raise EnvironmentError( - "There was a specific connection error when trying to load" - f" {pretrained_model_name_or_path}:\n{err}" - ) - except ValueError: - raise EnvironmentError( - f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" - f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" - f" directory containing a file named {WEIGHTS_NAME} or" - " \nCheckout your internet connection or see how to run the library in" - " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." - ) - except EnvironmentError: - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " - "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a file named {WEIGHTS_NAME}" - ) - - # restore default dtype + except: + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) if low_cpu_mem_usage: # Instantiate model with empty weights with accelerate.init_empty_weights(): - model, unused_kwargs = cls.from_config( + config, unused_kwargs = cls.load_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, @@ -462,6 +434,7 @@ class ModelMixin(torch.nn.Module): device_map=device_map, **kwargs, ) + model = cls.from_config(config, **unused_kwargs) # if device_map is Non,e load the state dict on move the params from meta device to the cpu if device_map is None: @@ -482,7 +455,7 @@ class ModelMixin(torch.nn.Module): "error_msgs": [], } else: - model, unused_kwargs = cls.from_config( + config, unused_kwargs = cls.load_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, @@ -496,6 +469,7 @@ class ModelMixin(torch.nn.Module): device_map=device_map, **kwargs, ) + model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( @@ -689,3 +663,88 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: return unwrap_model(model.module) else: return model + + +def _get_model_file( + pretrained_model_name_or_path, + *, + weights_name, + subfolder, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + use_auth_token, + user_agent, + revision, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, weights_name) + return model_file + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + return model_file + else: + raise EnvironmentError( + f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=weights_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + return model_file + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {weights_name} or" + " \nCheckout your internet connection or see how to run the library in" + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {weights_name}" + ) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index be9203b4d6..e9454a467a 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import warnings from dataclasses import dataclass from typing import Optional @@ -98,8 +99,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin): num_vector_embeds: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, ): super().__init__() + self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim inner_dim = num_attention_heads * attention_head_dim @@ -125,7 +129,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin): self.in_channels = in_channels self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" @@ -151,6 +158,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): activation_fn=activation_fn, num_embeds_ada_norm=num_embeds_ada_norm, attention_bias=attention_bias, + only_cross_attention=only_cross_attention, ) for d in range(num_layers) ] @@ -158,7 +166,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin): # 4. Define output layers if self.is_input_continuous: - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) @@ -190,10 +201,16 @@ class Transformer2DModel(ModelMixin, ConfigMixin): if self.is_input_continuous: batch, channel, height, weight = hidden_states.shape residual = hidden_states + hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) @@ -203,8 +220,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin): # 3. Output if self.is_input_continuous: - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) - hidden_states = self.proj_out(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + output = hidden_states + residual elif self.is_input_vectorized: hidden_states = self.norm_out(hidden_states) @@ -284,22 +310,52 @@ class AttentionBlock(nn.Module): key_proj = self.key(hidden_states) value_proj = self.value(hidden_states) - # transpose - query_states = self.transpose_for_scores(query_proj) - key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) + scale = 1 / math.sqrt(self.channels / self.num_heads) # get scores - scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm + if self.num_heads > 1: + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors? + # or reformulate this into a 3D problem? + # TODO: measure whether on MPS device it would be faster to do this matmul via einsum + # as some matmuls can be 1.94x slower than an equivalent einsum on MPS + # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0 + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale + else: + query_states, key_states, value_states = query_proj, key_proj, value_proj + + attention_scores = torch.baddbmm( + torch.empty( + query_states.shape[0], + query_states.shape[1], + key_states.shape[1], + dtype=query_states.dtype, + device=query_states.device, + ), + query_states, + key_states.transpose(-1, -2), + beta=0, + alpha=scale, + ) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) # compute attention output - hidden_states = torch.matmul(attention_probs, value_states) - - hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() - new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) - hidden_states = hidden_states.view(new_hidden_states_shape) + if self.num_heads > 1: + # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors? + # or reformulate this into a 3D problem? + # TODO: measure whether on MPS device it would be faster to do this matmul via einsum + # as some matmuls can be 1.94x slower than an equivalent einsum on MPS + # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0 + hidden_states = torch.matmul(attention_probs, value_states) + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + else: + hidden_states = torch.bmm(attention_probs, value_states) # compute next hidden_states hidden_states = self.proj_attn(hidden_states) @@ -337,14 +393,17 @@ class BasicTransformerBlock(nn.Module): activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, attention_bias: bool = False, + only_cross_attention: bool = False, ): super().__init__() + self.only_cross_attention = only_cross_attention self.attn1 = CrossAttention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, ) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) self.attn2 = CrossAttention( @@ -366,6 +425,16 @@ class BasicTransformerBlock(nn.Module): self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) + # if xformers is installed try to use memory_efficient_attention by default + if is_xformers_available(): + try: + self._set_use_memory_efficient_attention_xformers(True) + except Exception as e: + warnings.warn( + "Could not enable memory efficient attention. Make sure xformers is installed" + f" correctly and a GPU is available: {e}" + ) + def _set_attention_slice(self, slice_size): self.attn1._slice_size = slice_size self.attn2._slice_size = slice_size @@ -401,7 +470,11 @@ class BasicTransformerBlock(nn.Module): norm_hidden_states = ( self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) ) - hidden_states = self.attn1(norm_hidden_states) + hidden_states + + if self.only_cross_attention: + hidden_states = self.attn1(norm_hidden_states, context) + hidden_states + else: + hidden_states = self.attn1(norm_hidden_states) + hidden_states # 2. Cross-Attention norm_hidden_states = ( @@ -507,19 +580,17 @@ class CrossAttention(nn.Module): return hidden_states def _attention(self, query, key, value): - # TODO: use baddbmm for better performance - if query.device.type == "mps": - # Better performance on mps (~20-25%) - attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale - else: - attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) attention_probs = attention_scores.softmax(dim=-1) # compute attention output - if query.device.type == "mps": - hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value) - else: - hidden_states = torch.matmul(attention_probs, value) + hidden_states = torch.bmm(attention_probs, value) # reshape hidden_states hidden_states = self.reshape_batch_dim_to_heads(hidden_states) @@ -534,21 +605,15 @@ class CrossAttention(nn.Module): for i in range(hidden_states.shape[0] // slice_size): start_idx = i * slice_size end_idx = (i + 1) * slice_size - if query.device.type == "mps": - # Better performance on mps (~20-25%) - attn_slice = ( - torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) - * self.scale - ) - else: - attn_slice = ( - torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale - ) # TODO: use baddbmm for better performance + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query[start_idx:end_idx], + key[start_idx:end_idx].transpose(-1, -2), + beta=0, + alpha=self.scale, + ) attn_slice = attn_slice.softmax(dim=-1) - if query.device.type == "mps": - attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) - else: - attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice @@ -666,3 +731,129 @@ class AdaLayerNorm(nn.Module): scale, shift = torch.chunk(emb, 2) x = self.norm(x) * (1 + scale) + shift return x + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of context dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[ + 0 + ] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) + + def _set_attention_slice(self, slice_size): + for transformer in self.transformers: + transformer._set_attention_slice(slice_size) + + def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for transformer in self.transformers: + transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 1b86094747..71106e0545 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -104,6 +104,8 @@ class FlaxBasicTransformerBlock(nn.Module): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + only_cross_attention (`bool`, defaults to `False`): + Whether to only apply cross attention. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -111,10 +113,11 @@ class FlaxBasicTransformerBlock(nn.Module): n_heads: int d_head: int dropout: float = 0.0 + only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): - # self attention + # self attention (or cross_attention if only_cross_attention is True) self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) # cross attention self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) @@ -126,7 +129,10 @@ class FlaxBasicTransformerBlock(nn.Module): def __call__(self, hidden_states, context, deterministic=True): # self attention residual = hidden_states - hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) + if self.only_cross_attention: + hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic) + else: + hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual # cross attention @@ -159,6 +165,8 @@ class FlaxTransformer2DModel(nn.Module): Number of transformers block dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + use_linear_projection (`bool`, defaults to `False`): tbd + only_cross_attention (`bool`, defaults to `False`): tbd dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -167,49 +175,70 @@ class FlaxTransformer2DModel(nn.Module): d_head: int depth: int = 1 dropout: float = 0.0 + use_linear_projection: bool = False + only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) inner_dim = self.n_heads * self.d_head - self.proj_in = nn.Conv( - inner_dim, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype=self.dtype, - ) + if self.use_linear_projection: + self.proj_in = nn.Dense(inner_dim, dtype=self.dtype) + else: + self.proj_in = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) self.transformer_blocks = [ - FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype) + FlaxBasicTransformerBlock( + inner_dim, + self.n_heads, + self.d_head, + dropout=self.dropout, + only_cross_attention=self.only_cross_attention, + dtype=self.dtype, + ) for _ in range(self.depth) ] - self.proj_out = nn.Conv( - inner_dim, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype=self.dtype, - ) + if self.use_linear_projection: + self.proj_out = nn.Dense(inner_dim, dtype=self.dtype) + else: + self.proj_out = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) def __call__(self, hidden_states, context, deterministic=True): batch, height, width, channels = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - - hidden_states = hidden_states.reshape(batch, height * width, channels) + if self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height * width, channels) + hidden_states = self.proj_in(hidden_states) + else: + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.reshape(batch, height * width, channels) for transformer_block in self.transformer_blocks: hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) - hidden_states = hidden_states.reshape(batch, height, width, channels) + if self.use_linear_projection: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, channels) + else: + hidden_states = hidden_states.reshape(batch, height, width, channels) + hidden_states = self.proj_out(hidden_states) - hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states + residual - return hidden_states diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 0432405760..5b337f482c 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): implements for all the model (such as downloading or saving, etc.) Parameters: - sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): - Input sample size. + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. @@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, - sample_size: Optional[int] = None, + sample_size: Optional[Union[int, Tuple[int, int]]] = None, in_channels: int = 3, out_channels: int = 3, center_input_sample: bool = False, @@ -175,7 +175,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32) self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps) self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def forward( self, @@ -209,6 +209,11 @@ class UNet2DModel(ModelMixin, ConfigMixin): timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) # 2. pre-process @@ -242,9 +247,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): sample = upsample_block(sample, res_samples, emb) # 6. post-process - # make sure hidden states is in float32 - # when running in half-precision - sample = self.conv_norm_out(sample.float()).type(sample.dtype) + sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 770043f053..cce7e7fd5a 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -15,7 +15,7 @@ import numpy as np import torch from torch import nn -from .attention import AttentionBlock, Transformer2DModel +from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D @@ -32,6 +32,9 @@ def get_down_block( resnet_groups=None, cross_attention_dim=None, downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": @@ -74,6 +77,9 @@ def get_down_block( downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( @@ -137,6 +143,9 @@ def get_up_block( attn_num_head_channels, resnet_groups=None, cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -166,6 +175,9 @@ def get_up_block( resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, ) elif up_block_type == "AttnUpBlock2D": return AttnUpBlock2D( @@ -242,7 +254,6 @@ class UNetMidBlock2D(nn.Module): attn_num_head_channels=1, attention_type="default", output_scale_factor=1.0, - **kwargs, ): super().__init__() @@ -322,7 +333,8 @@ class UNetMidBlock2DCrossAttn(nn.Module): attention_type="default", output_scale_factor=1.0, cross_attention_dim=1280, - **kwargs, + dual_cross_attention=False, + use_linear_projection=False, ): super().__init__() @@ -348,16 +360,29 @@ class UNetMidBlock2DCrossAttn(nn.Module): attentions = [] for _ in range(num_layers): - attentions.append( - Transformer2DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) resnets.append( ResnetBlock2D( in_channels=in_channels, @@ -377,15 +402,17 @@ class UNetMidBlock2DCrossAttn(nn.Module): self.resnets = nn.ModuleList(resnets) def set_attention_slice(self, slice_size): - if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" ) - if slice_size is not None and slice_size > self.attn_num_head_channels: + if slice_size is not None and slice_size > min(head_dims): raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" ) for attn in self.attentions: @@ -505,6 +532,9 @@ class CrossAttnDownBlock2D(nn.Module): output_scale_factor=1.0, downsample_padding=1, add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, ): super().__init__() resnets = [] @@ -529,16 +559,30 @@ class CrossAttnDownBlock2D(nn.Module): pre_norm=resnet_pre_norm, ) ) - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -556,15 +600,17 @@ class CrossAttnDownBlock2D(nn.Module): self.gradient_checkpointing = False def set_attention_slice(self, slice_size): - if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" ) - if slice_size is not None and slice_size > self.attn_num_head_channels: + if slice_size is not None and slice_size > min(head_dims): raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" ) for attn in self.attentions: @@ -1089,6 +1135,9 @@ class CrossAttnUpBlock2D(nn.Module): attention_type="default", output_scale_factor=1.0, add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, ): super().__init__() resnets = [] @@ -1115,16 +1164,30 @@ class CrossAttnUpBlock2D(nn.Module): pre_norm=resnet_pre_norm, ) ) - attentions.append( - Transformer2DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -1136,15 +1199,17 @@ class CrossAttnUpBlock2D(nn.Module): self.gradient_checkpointing = False def set_attention_slice(self, slice_size): - if slice_size is not None and self.attn_num_head_channels % slice_size != 0: + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" ) - if slice_size is not None and slice_size > self.attn_num_head_channels: + if slice_size is not None and slice_size > min(head_dims): raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.attn_num_head_channels}" + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" ) for attn in self.attentions: diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py index 5798385b9d..96e76cb06a 100644 --- a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unet_2d_blocks_flax.py @@ -46,6 +46,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module): num_layers: int = 1 attn_num_head_channels: int = 1 add_downsample: bool = True + use_linear_projection: bool = False + only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -68,6 +70,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module): n_heads=self.attn_num_head_channels, d_head=self.out_channels // self.attn_num_head_channels, depth=1, + use_linear_projection=self.use_linear_projection, + only_cross_attention=self.only_cross_attention, dtype=self.dtype, ) attentions.append(attn_block) @@ -178,6 +182,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): num_layers: int = 1 attn_num_head_channels: int = 1 add_upsample: bool = True + use_linear_projection: bool = False + only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -201,6 +207,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): n_heads=self.attn_num_head_channels, d_head=self.out_channels // self.attn_num_head_channels, depth=1, + use_linear_projection=self.use_linear_projection, + only_cross_attention=self.only_cross_attention, dtype=self.dtype, ) attentions.append(attn_block) @@ -310,6 +318,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): dropout: float = 0.0 num_layers: int = 1 attn_num_head_channels: int = 1 + use_linear_projection: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -331,6 +340,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): n_heads=self.attn_num_head_channels, d_head=self.in_channels // self.attn_num_head_channels, depth=1, + use_linear_projection=self.use_linear_projection, dtype=self.dtype, ) attentions.append(attn_block) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index becae75683..1b43f960d9 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -56,11 +56,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): implements for all the models (such as downloading or saving, etc.) Parameters: - sample_size (`int`, *optional*): The size of the input sample. + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): @@ -97,6 +98,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): "DownBlock2D", ), up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, @@ -105,7 +107,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): norm_num_groups: int = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, - attention_head_dim: int = 8, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + num_class_embeds: Optional[int] = None, ): super().__init__() @@ -121,10 +126,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + # class embedding + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + self.down_blocks = nn.ModuleList([]) self.mid_block = None self.up_blocks = nn.ModuleList([]) + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): @@ -143,8 +158,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=attention_head_dim[i], downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], ) self.down_blocks.append(down_block) @@ -157,8 +175,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift="default", cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=attention_head_dim[-1], resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, ) # count how many layers upsample the images @@ -166,6 +186,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): # up reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 @@ -193,7 +215,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -201,18 +226,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): # out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) self.conv_act = nn.SiLU() - self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): - if slice_size is not None and self.config.attention_head_dim % slice_size != 0: + head_dims = self.config.attention_head_dim + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): raise ValueError( - f"Make sure slice_size {slice_size} is a divisor of " - f"the number of heads used in cross_attention {self.config.attention_head_dim}" + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" ) - if slice_size is not None and slice_size > self.config.attention_head_dim: + if slice_size is not None and slice_size > min(head_dims): raise ValueError( - f"Chunk_size {slice_size} has to be smaller or equal to " - f"the number of heads used in cross_attention {self.config.attention_head_dim}" + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" ) for block in self.down_blocks: @@ -245,6 +272,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -297,6 +325,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) + if self.config.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + # 2. pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index f0e721826b..8a33853700 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -79,7 +79,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - attention_head_dim (`int`, *optional*, defaults to 8): + attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): The dimension of the attention heads. cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. @@ -97,11 +97,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): "DownBlock2D", ) up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + only_cross_attention: Union[bool, Tuple[bool]] = False block_out_channels: Tuple[int] = (320, 640, 1280, 1280) layers_per_block: int = 2 - attention_head_dim: int = 8 + attention_head_dim: Union[int, Tuple[int]] = 8 cross_attention_dim: int = 1280 dropout: float = 0.0 + use_linear_projection: bool = False dtype: jnp.dtype = jnp.float32 freq_shift: int = 0 @@ -134,6 +136,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift) self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + only_cross_attention = self.only_cross_attention + if isinstance(only_cross_attention, bool): + only_cross_attention = (only_cross_attention,) * len(self.down_block_types) + + attention_head_dim = self.attention_head_dim + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(self.down_block_types) + # down down_blocks = [] output_channel = block_out_channels[0] @@ -148,8 +158,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): out_channels=output_channel, dropout=self.dropout, num_layers=self.layers_per_block, - attn_num_head_channels=self.attention_head_dim, + attn_num_head_channels=attention_head_dim[i], add_downsample=not is_final_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], dtype=self.dtype, ) else: @@ -169,13 +181,16 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): self.mid_block = FlaxUNetMidBlock2DCrossAttn( in_channels=block_out_channels[-1], dropout=self.dropout, - attn_num_head_channels=self.attention_head_dim, + attn_num_head_channels=attention_head_dim[-1], + use_linear_projection=self.use_linear_projection, dtype=self.dtype, ) # up up_blocks = [] reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(self.up_block_types): prev_output_channel = output_channel @@ -190,9 +205,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): out_channels=output_channel, prev_output_channel=prev_output_channel, num_layers=self.layers_per_block + 1, - attn_num_head_channels=self.attention_head_dim, + attn_num_head_channels=reversed_attention_head_dim[i], add_upsample=not is_final_block, dropout=self.dropout, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], dtype=self.dtype, ) else: @@ -230,9 +247,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ) -> Union[FlaxUNet2DConditionOutput, Tuple]: r""" Args: - sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor + sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor timestep (`jnp.ndarray` or `float` or `int`): timesteps - encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states + encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple. diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 30de343d08..e29f4e8afa 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -565,6 +565,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) + self.use_slicing = False def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: h = self.encoder(x) @@ -576,7 +577,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): return AutoencoderKLOutput(latent_dist=posterior) - def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: z = self.post_quant_conv(z) dec = self.decoder(z) @@ -585,6 +586,34 @@ class AutoencoderKL(ModelMixin, ConfigMixin): return DecoderOutput(sample=dec) + def enable_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + def forward( self, sample: torch.FloatTensor, diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 4c34e64f78..f8fd304776 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -47,7 +47,7 @@ logger = logging.get_logger(__name__) LOADABLE_CLASSES = { "diffusers": { "FlaxModelMixin": ["save_pretrained", "from_pretrained"], - "FlaxSchedulerMixin": ["save_config", "from_config"], + "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"], "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], }, "transformers": { @@ -280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin): >>> from diffusers import FlaxDPMSolverMultistepScheduler >>> model_id = "runwayml/stable-diffusion-v1-5" - >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_config( + >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained( ... model_id, ... subfolder="scheduler", ... ) @@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin): # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): - config_dict = cls.get_config_dict( + config_dict = cls.load_config( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, @@ -317,8 +317,8 @@ class FlaxDiffusionPipeline(ConfigMixin): allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] - # make sure we don't download PyTorch weights - ignore_patterns = "*.bin" + # make sure we don't download PyTorch weights, unless when using from_pt + ignore_patterns = "*.bin" if not from_pt else [] if cls != FlaxDiffusionPipeline: requested_pipeline_class = cls.__name__ @@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin): else: cached_folder = pretrained_model_name_or_path - config_dict = cls.get_config_dict(cached_folder) + config_dict = cls.load_config(cached_folder) # 2. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it @@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin): expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} @@ -411,13 +411,13 @@ class FlaxDiffusionPipeline(ConfigMixin): f" {expected_class_obj}" ) elif passed_class_obj[name] is None: - logger.warn( + logger.warning( f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" f" that this might lead to problems when using {pipeline_class} and is not recommended." ) sub_model_should_be_defined = False else: - logger.warn( + logger.warning( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" " has the correct type" ) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index a194f3eb34..01bcc6a338 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -18,6 +18,7 @@ import importlib import inspect import os from dataclasses import dataclass +from pathlib import Path from typing import Any, Dict, List, Optional, Union import numpy as np @@ -25,7 +26,7 @@ import torch import diffusers import PIL -from huggingface_hub import snapshot_download +from huggingface_hub import model_info, snapshot_download from packaging import version from PIL import Image from tqdm.auto import tqdm @@ -43,6 +44,7 @@ from .utils import ( BaseOutput, deprecate, is_accelerate_available, + is_safetensors_available, is_torch_version, is_transformers_available, logging, @@ -57,6 +59,7 @@ if is_transformers_available(): INDEX_FILE = "diffusion_pytorch_model.bin" CUSTOM_PIPELINE_FILE_NAME = "pipeline.py" DUMMY_MODULES_FOLDER = "diffusers.utils" +TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils" logger = logging.get_logger(__name__) @@ -65,7 +68,7 @@ logger = logging.get_logger(__name__) LOADABLE_CLASSES = { "diffusers": { "ModelMixin": ["save_pretrained", "from_pretrained"], - "SchedulerMixin": ["save_config", "from_config"], + "SchedulerMixin": ["save_pretrained", "from_pretrained"], "DiffusionPipeline": ["save_pretrained", "from_pretrained"], "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], }, @@ -77,6 +80,9 @@ LOADABLE_CLASSES = { "ProcessorMixin": ["save_pretrained", "from_pretrained"], "ImageProcessingMixin": ["save_pretrained", "from_pretrained"], }, + "onnxruntime.training": { + "ORTModule": ["save_pretrained", "from_pretrained"], + }, } ALL_IMPORTABLE_CLASSES = {} @@ -112,6 +118,23 @@ class AudioPipelineOutput(BaseOutput): audios: np.ndarray +def is_safetensors_compatible(info) -> bool: + filenames = set(sibling.rfilename for sibling in info.siblings) + pt_filenames = set(filename for filename in filenames if filename.endswith(".bin")) + is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames) + for pt_filename in pt_filenames: + prefix, raw = os.path.split(pt_filename) + if raw == "pytorch_model.bin": + # transformers specific + sf_filename = os.path.join(prefix, "model.safetensors") + else: + sf_filename = pt_filename[: -len(".bin")] + ".safetensors" + if is_safetensors_compatible and sf_filename not in filenames: + logger.warning(f"{sf_filename} not found") + is_safetensors_compatible = False + return is_safetensors_compatible + + class DiffusionPipeline(ConfigMixin): r""" Base class for all models. @@ -124,10 +147,13 @@ class DiffusionPipeline(ConfigMixin): Class attributes: - - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all + - **config_name** (`str`) -- name of the config file that will store the class and module names of all components of the diffusion pipeline. + - **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be + passed for the pipeline to function (should be overridden by subclasses). """ config_name = "model_index.json" + _optional_components = [] def register_modules(self, **kwargs): # import it here to avoid circular import @@ -179,12 +205,19 @@ class DiffusionPipeline(ConfigMixin): model_index_dict.pop("_diffusers_version") model_index_dict.pop("_module", None) + expected_modules, optional_kwargs = self._get_signature_keys(self) + + def is_saveable_module(name, value): + if name not in expected_modules: + return False + if name in self._optional_components and value[0] is None: + return False + return True + + model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)} + for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) - if sub_model is None: - # edge case for saving a pipeline with safety_checker=None - continue - model_cls = sub_model.__class__ save_method_name = None @@ -207,7 +240,7 @@ class DiffusionPipeline(ConfigMixin): if torch_device is None: return self - module_names, _ = self.extract_init_dict(dict(self.config)) + module_names, _, _ = self.extract_init_dict(dict(self.config)) for name in module_names.keys(): module = getattr(self, name) if isinstance(module, torch.nn.Module): @@ -228,7 +261,7 @@ class DiffusionPipeline(ConfigMixin): Returns: `torch.device`: The torch device on which the pipeline is located. """ - module_names, _ = self.extract_init_dict(dict(self.config)) + module_names, _, _ = self.extract_init_dict(dict(self.config)) for name in module_names.keys(): module = getattr(self, name) if isinstance(module, torch.nn.Module): @@ -377,11 +410,11 @@ class DiffusionPipeline(ConfigMixin): >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> # Download pipeline, but overwrite scheduler + >>> # Use a different scheduler >>> from diffusers import LMSDiscreteScheduler - >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") - >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler) + >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.scheduler = scheduler ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) @@ -400,7 +433,7 @@ class DiffusionPipeline(ConfigMixin): if low_cpu_mem_usage and not is_accelerate_available(): low_cpu_mem_usage = False - logger.warn( + logger.warning( "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" @@ -428,7 +461,7 @@ class DiffusionPipeline(ConfigMixin): # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): - config_dict = cls.get_config_dict( + config_dict = cls.load_config( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, @@ -444,7 +477,7 @@ class DiffusionPipeline(ConfigMixin): allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name] # make sure we don't download flax weights - ignore_patterns = "*.msgpack" + ignore_patterns = ["*.msgpack"] if custom_pipeline is not None: allow_patterns += [CUSTOM_PIPELINE_FILE_NAME] @@ -458,6 +491,15 @@ class DiffusionPipeline(ConfigMixin): user_agent["custom_pipeline"] = custom_pipeline user_agent = http_user_agent(user_agent) + if is_safetensors_available(): + info = model_info( + pretrained_model_name_or_path, + use_auth_token=use_auth_token, + revision=revision, + ) + if is_safetensors_compatible(info): + ignore_patterns.append("*.bin") + # download all allow_patterns cached_folder = snapshot_download( pretrained_model_name_or_path, @@ -474,13 +516,21 @@ class DiffusionPipeline(ConfigMixin): else: cached_folder = pretrained_model_name_or_path - config_dict = cls.get_config_dict(cached_folder) + config_dict = cls.load_config(cached_folder) # 2. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it if custom_pipeline is not None: + if custom_pipeline.endswith(".py"): + path = Path(custom_pipeline) + # decompose into folder & file + file_name = path.name + custom_pipeline = path.parent.absolute() + else: + file_name = CUSTOM_PIPELINE_FILE_NAME + pipeline_class = get_class_from_dynamic_module( - custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline + custom_pipeline, module_file=file_name, cache_dir=custom_pipeline ) elif cls != DiffusionPipeline: pipeline_class = cls @@ -510,38 +560,47 @@ class DiffusionPipeline(ConfigMixin): # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` # extract them here - expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"]) + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} - init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs) + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + # define init kwargs + init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict} + init_kwargs = {**init_kwargs, **passed_pipe_kwargs} + + # remove `null` components + def load_module(name, value): + if value[0] is None: + return False + if name in passed_class_obj and passed_class_obj[name] is None: + return False + return True + + init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} if len(unused_kwargs) > 0: - logger.warning(f"Keyword arguments {unused_kwargs} not recognized.") - - init_kwargs = {} + logger.warning( + f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored." + ) # import it here to avoid circular import from diffusers import pipelines # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): - if class_name is None: - # edge case for when the pipeline was saved with safety_checker=None - init_kwargs[name] = None - continue - # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names if class_name.startswith("Flax"): class_name = class_name[4:] is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None - sub_model_should_be_defined = True # if the model is in a pipeline module, then we load it from the pipeline if name in passed_class_obj: # 1. check that passed_class_obj has correct parent class - if not is_pipeline_module and passed_class_obj[name] is not None: + if not is_pipeline_module: library = importlib.import_module(library_name) class_obj = getattr(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] @@ -557,14 +616,8 @@ class DiffusionPipeline(ConfigMixin): f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" ) - elif passed_class_obj[name] is None: - logger.warn( - f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" - f" that this might lead to problems when using {pipeline_class} and is not recommended." - ) - sub_model_should_be_defined = False else: - logger.warn( + logger.warning( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" " has the correct type" ) @@ -584,7 +637,7 @@ class DiffusionPipeline(ConfigMixin): importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} - if loaded_sub_model is None and sub_model_should_be_defined: + if loaded_sub_model is None: load_method_name = None for class_name, class_candidate in class_candidates.items(): if class_candidate is not None and issubclass(class_obj, class_candidate): @@ -592,7 +645,10 @@ class DiffusionPipeline(ConfigMixin): if load_method_name is None: none_module = class_obj.__module__ - if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module: + is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith( + TRANSFORMERS_DUMMY_MODULES_FOLDER + ) + if is_dummy_path and "dummy" in none_module: # call class_obj for nice error message of missing requirements class_obj() @@ -635,11 +691,13 @@ class DiffusionPipeline(ConfigMixin): # 4. Potentially add passed objects if expected missing_modules = set(expected_modules) - set(init_kwargs.keys()) - if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()): + passed_modules = list(passed_class_obj.keys()) + optional_modules = pipeline_class._optional_components + if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules): for module in missing_modules: - init_kwargs[module] = passed_class_obj[module] + init_kwargs[module] = passed_class_obj.get(module, None) elif len(missing_modules) > 0: - passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) + passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs raise ValueError( f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." ) @@ -648,6 +706,14 @@ class DiffusionPipeline(ConfigMixin): model = pipeline_class(**init_kwargs) return model + @staticmethod + def _get_signature_keys(obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - set(["self"]) + return expected_modules, optional_parameters + @property def components(self) -> Dict[str, Any]: r""" @@ -664,16 +730,18 @@ class DiffusionPipeline(ConfigMixin): ... StableDiffusionInpaintPipeline, ... ) - >>> img2text = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components) - >>> inpaint = StableDiffusionInpaintPipeline(**img2text.components) + >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components) + >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components) ``` Returns: A dictionaly containing all the modules needed to initialize the pipeline. """ - components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")} - expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"]) + expected_modules, optional_parameters = self._get_signature_keys(self) + components = { + k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters + } if set(components.keys()) != expected_modules: raise ValueError( @@ -699,7 +767,7 @@ class DiffusionPipeline(ConfigMixin): return pil_images - def progress_bar(self, iterable): + def progress_bar(self, iterable=None, total=None): if not hasattr(self, "_progress_bar_config"): self._progress_bar_config = {} elif not isinstance(self._progress_bar_config, dict): @@ -707,7 +775,12 @@ class DiffusionPipeline(ConfigMixin): f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." ) - return tqdm(iterable, **self._progress_bar_config) + if iterable is not None: + return tqdm(iterable, **self._progress_bar_config) + elif total is not None: + return tqdm(total=total, **self._progress_bar_config) + else: + raise ValueError("Either `total` or `iterable` has to be defined.") def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md index 2941660fa2..6ff40d3549 100644 --- a/src/diffusers/pipelines/README.md +++ b/src/diffusers/pipelines/README.md @@ -40,7 +40,7 @@ available a colab notebook to directly try them out. | [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* | | [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | | [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | -| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) | [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* | diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ef4d23e5e6..c5aba30204 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -15,13 +15,23 @@ else: from ..utils.dummy_pt_objects import * # noqa F403 if is_torch_available() and is_transformers_available(): + from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline from .latent_diffusion import LDMTextToImagePipeline from .stable_diffusion import ( CycleDiffusionPipeline, + StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, + StableDiffusionUpscalePipeline, + ) + from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .versatile_diffusion import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, ) from .vq_diffusion import VQDiffusionPipeline @@ -29,6 +39,7 @@ if is_transformers_available() and is_onnx_available(): from .stable_diffusion import ( OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline, ) diff --git a/src/diffusers/pipelines/alt_diffusion/__init__.py b/src/diffusers/pipelines/alt_diffusion/__init__.py new file mode 100644 index 0000000000..09d0d9b785 --- /dev/null +++ b/src/diffusers/pipelines/alt_diffusion/__init__.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np + +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_torch_available, is_transformers_available + + +@dataclass +# Copied from diffusers.pipelines.stable_diffusion.__init__.StableDiffusionPipelineOutput with Stable->Alt +class AltDiffusionPipelineOutput(BaseOutput): + """ + Output class for Alt Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + + +if is_transformers_available() and is_torch_available(): + from .modeling_roberta_series import RobertaSeriesModelWithTransformation + from .pipeline_alt_diffusion import AltDiffusionPipeline + from .pipeline_alt_diffusion_img2img import AltDiffusionImg2ImgPipeline diff --git a/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py new file mode 100644 index 0000000000..2e92314162 --- /dev/null +++ b/src/diffusers/pipelines/alt_diffusion/modeling_roberta_series.py @@ -0,0 +1,110 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn + +from transformers import RobertaPreTrainedModel, XLMRobertaConfig, XLMRobertaModel +from transformers.utils import ModelOutput + + +@dataclass +class TransformationModelOutput(ModelOutput): + """ + Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + projection_state: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__( + self, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + project_dim=512, + pooler_fn="cls", + learn_encoder=False, + use_attention_mask=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + self.use_attention_mask = use_attention_mask + + +class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + base_model_prefix = "roberta" + config_class = RobertaSeriesConfig + + def __init__(self, config): + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size, config.project_dim) + self.post_init() + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ): + r""" """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + projection_state = self.transformation(outputs.last_hidden_state) + + return TransformationModelOutput( + projection_state=projection_state, + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py new file mode 100644 index 0000000000..9146d45bd3 --- /dev/null +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -0,0 +1,597 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch + +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import deprecate, logging +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker +class AltDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`RobertaSeriesModelWithTransformation`]): + Frozen text-encoder. Alt Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`XLMRobertaTokenizer`): + Tokenizer of class + [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py new file mode 100644 index 0000000000..16dbd626cd --- /dev/null +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -0,0 +1,614 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, deprecate, logging +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker +class AltDiffusionImg2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Alt Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`RobertaSeriesModelWithTransformation`]): + Frozen text-encoder. Alt Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.RobertaSeriesModelWithTransformation), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`XLMRobertaTokenizer`): + Tokenizer of class + [XLMRobertaTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.XLMRobertaTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: RobertaSeriesModelWithTransformation, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, strength, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + init_image = init_image.to(device=device, dtype=dtype) + init_latent_dist = self.vae.encode(init_image).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`init_image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many init images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(init_image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `init_image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Preprocess image + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + init_image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return AltDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 6db6298329..b9e590dea6 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -89,7 +89,11 @@ class DDIMPipeline(DiffusionPipeline): generator = None # Sample gaussian noise to begin loop - image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + if isinstance(self.unet.sample_size, int): + image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + else: + image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size) + if self.device.type == "mps": # randn does not work reproducibly on mps image = torch.randn(image_shape, generator=generator) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index b7194664f4..31791caf9e 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -70,14 +70,14 @@ class DDPMPipeline(DiffusionPipeline): generated images. """ message = ( - "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" - " DDPMScheduler.from_config(, predict_epsilon=True)`." + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DDPMScheduler.from_pretrained(, prediction_type='epsilon')`." ) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) if predict_epsilon is not None: new_config = dict(self.scheduler.config) - new_config["predict_epsilon"] = predict_epsilon + new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample" self.scheduler._internal_dict = FrozenDict(new_config) if generator is not None and generator.device.type != self.device.type and self.device.type != "mps": @@ -94,7 +94,11 @@ class DDPMPipeline(DiffusionPipeline): generator = None # Sample gaussian noise to begin loop - image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + if isinstance(self.unet.sample_size, int): + image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size) + else: + image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size) + if self.device.type == "mps": # randn does not work reproducibly on mps image = torch.randn(image_shape, generator=generator) @@ -110,9 +114,7 @@ class DDPMPipeline(DiffusionPipeline): model_output = self.unet(image, t).sample # 2. compute previous image: x_t -> x_t-1 - image = self.scheduler.step( - model_output, t, image, generator=generator, predict_epsilon=predict_epsilon - ).prev_sample + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index feb5b00d74..0e903cb836 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -60,13 +60,14 @@ class LDMTextToImagePipeline(DiffusionPipeline): ): super().__init__() self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], - height: Optional[int] = 256, - width: Optional[int] = 256, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 1.0, eta: Optional[float] = 0.0, @@ -79,9 +80,9 @@ class LDMTextToImagePipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 256): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 256): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -106,6 +107,9 @@ class LDMTextToImagePipeline(DiffusionPipeline): `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor if isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py index 044ff359e3..b296a4953f 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -17,12 +17,13 @@ from ...schedulers import ( LMSDiscreteScheduler, PNDMScheduler, ) +from ...utils import PIL_INTERPOLATION def preprocess(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md index a76e4c6682..bc30be4a7b 100644 --- a/src/diffusers/pipelines/stable_diffusion/README.md +++ b/src/diffusers/pipelines/stable_diffusion/README.md @@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png") # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, DDIMScheduler -scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") +scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", @@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png") # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler -lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") +lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", @@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler # load the pipeline # make sure you're logged in with `huggingface-cli login` model_id_or_path = "CompVis/stable-diffusion-v1-4" -scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler") +scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") # let's download an initial image diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 6623929f86..80ac88e1f4 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -6,7 +6,14 @@ import numpy as np import PIL from PIL import Image -from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available +from ...utils import ( + BaseOutput, + is_flax_available, + is_onnx_available, + is_torch_available, + is_transformers_available, + is_transformers_version, +) @dataclass @@ -33,12 +40,19 @@ if is_transformers_available() and is_torch_available(): from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy + from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline from .safety_checker import StableDiffusionSafetyChecker +if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"): + from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline +else: + from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline + if is_transformers_available() and is_onnx_available(): from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline + from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy if is_transformers_available() and is_flax_available(): import flax @@ -49,15 +63,14 @@ if is_transformers_available() and is_flax_available(): Output class for Stable Diffusion pipelines. Args: - images (`List[PIL.Image.Image]` or `np.ndarray`) - List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, - num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + images (`np.ndarray`) + Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline. nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content. """ - images: Union[List[PIL.Image.Image], np.ndarray] + images: np.ndarray nsfw_content_detected: List[bool] from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index dfdb58de4d..9ebbc249f6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -20,13 +20,14 @@ import torch import PIL from diffusers.utils import is_accelerate_available +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler -from ...utils import deprecate, logging +from ...utils import PIL_INTERPOLATION, deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -37,7 +38,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name def preprocess(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) @@ -132,6 +133,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, @@ -142,6 +144,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): scheduler: DDIMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -159,8 +162,8 @@ class CycleDiffusionPipeline(DiffusionPipeline): new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: - logger.warn( + if safety_checker is None and requires_safety_checker: + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -169,6 +172,32 @@ class CycleDiffusionPipeline(DiffusionPipeline): " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -178,6 +207,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): @@ -194,9 +224,14 @@ class CycleDiffusionPipeline(DiffusionPipeline): `attention_head_dim` must be a multiple of `slice_size`. """ if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + self.unet.set_attention_slice(slice_size) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing @@ -209,7 +244,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): self.enable_attention_slicing(None) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self): + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -220,12 +255,17 @@ class CycleDiffusionPipeline(DiffusionPipeline): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): @@ -301,7 +341,17 @@ class CycleDiffusionPipeline(DiffusionPipeline): "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_embeddings = self.text_encoder(text_input_ids.to(device))[0] + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape @@ -337,7 +387,17 @@ class CycleDiffusionPipeline(DiffusionPipeline): truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] @@ -415,7 +475,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:] - return timesteps + return timesteps, num_inference_steps - t_start def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): init_image = init_image.to(device=device, dtype=dtype) @@ -468,7 +528,6 @@ class CycleDiffusionPipeline(DiffusionPipeline): return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, - **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -548,7 +607,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -562,66 +621,70 @@ class CycleDiffusionPipeline(DiffusionPipeline): generator = extra_step_kwargs.pop("generator", None) # 8. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - source_latent_model_input = torch.cat([source_latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + source_latent_model_input = torch.cat([source_latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t) - # predict the noise residual - concat_latent_model_input = torch.stack( - [ - source_latent_model_input[0], - latent_model_input[0], - source_latent_model_input[1], - latent_model_input[1], - ], - dim=0, - ) - concat_text_embeddings = torch.stack( - [ - source_text_embeddings[0], - text_embeddings[0], - source_text_embeddings[1], - text_embeddings[1], - ], - dim=0, - ) - concat_noise_pred = self.unet( - concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings - ).sample + # predict the noise residual + concat_latent_model_input = torch.stack( + [ + source_latent_model_input[0], + latent_model_input[0], + source_latent_model_input[1], + latent_model_input[1], + ], + dim=0, + ) + concat_text_embeddings = torch.stack( + [ + source_text_embeddings[0], + text_embeddings[0], + source_text_embeddings[1], + text_embeddings[1], + ], + dim=0, + ) + concat_noise_pred = self.unet( + concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings + ).sample - # perform guidance - ( - source_noise_pred_uncond, - noise_pred_uncond, - source_noise_pred_text, - noise_pred_text, - ) = concat_noise_pred.chunk(4, dim=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( - source_noise_pred_text - source_noise_pred_uncond - ) + # perform guidance + ( + source_noise_pred_uncond, + noise_pred_uncond, + source_noise_pred_text, + noise_pred_text, + ) = concat_noise_pred.chunk(4, dim=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + source_noise_pred = source_noise_pred_uncond + source_guidance_scale * ( + source_noise_pred_text - source_noise_pred_uncond + ) - # Sample source_latents from the posterior distribution. - prev_source_latents = posterior_sample( - self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs - ) - # Compute noise. - noise = compute_noise( - self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs - ) - source_latents = prev_source_latents + # Sample source_latents from the posterior distribution. + prev_source_latents = posterior_sample( + self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs + ) + # Compute noise. + noise = compute_noise( + self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs + ) + source_latents = prev_source_latents - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs - ).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs + ).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 9. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 02943997d9..23148dcfe2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -23,6 +23,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict from flax.jax_utils import unreplicate from flax.training.common_utils import shard +from packaging import version from PIL import Image from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel @@ -34,7 +35,7 @@ from ...schedulers import ( FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, ) -from ...utils import logging +from ...utils import deprecate, logging from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker @@ -88,7 +89,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): self.dtype = dtype if safety_checker is None: - logger.warn( + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -97,6 +98,27 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -106,6 +128,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) def prepare_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): @@ -160,12 +183,17 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): params: Union[Dict, FrozenDict], prng_seed: jax.random.PRNGKey, num_inference_steps: int = 50, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, guidance_scale: float = 7.5, latents: Optional[jnp.array] = None, debug: bool = False, + neg_prompt_ids: jnp.array = None, ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -177,13 +205,22 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): batch_size = prompt_ids.shape[0] max_length = prompt_ids.shape[-1] - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" - ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([uncond_embeddings, text_embeddings]) - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + latents_shape = ( + batch_size, + self.unet.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if latents is None: latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) else: @@ -244,14 +281,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): params: Union[Dict, FrozenDict], prng_seed: jax.random.PRNGKey, num_inference_steps: int = 50, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, guidance_scale: float = 7.5, latents: jnp.array = None, return_dict: bool = True, jit: bool = False, debug: bool = False, - **kwargs, + neg_prompt_ids: jnp.array = None, ): r""" Function invoked when calling the pipeline for generation. @@ -259,9 +296,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -279,9 +316,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. jit (`bool`, defaults to `False`): Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. @@ -296,13 +330,36 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + if jit: images = _p_generate( - self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + self, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + debug, + neg_prompt_ids, ) else: images = self._generate( - prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + debug, + neg_prompt_ids, ) if self.safety_checker is not None: @@ -322,6 +379,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): images = images.reshape(num_devices, batch_size, height, width, 3) else: + images = np.asarray(images) has_nsfw_concept = False if not return_dict: @@ -333,10 +391,29 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): # TODO: maybe use a config dict instead of so many static argnums @partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9)) def _p_generate( - pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + pipe, + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + debug, + neg_prompt_ids, ): return pipe._generate( - prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + prompt_ids, + params, + prng_seed, + num_inference_steps, + height, + width, + guidance_scale, + latents, + debug, + neg_prompt_ids, ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index eceefea874..1b9a8ff724 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -41,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): safety_checker: OnnxRuntimeModel feature_extractor: CLIPFeatureExtractor + _optional_components = ["safety_checker", "feature_extractor"] + def __init__( self, vae_encoder: OnnxRuntimeModel, @@ -51,6 +53,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -81,6 +84,22 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + self.register_modules( vae_encoder=vae_encoder, vae_decoder=vae_decoder, @@ -91,6 +110,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): r""" @@ -185,7 +205,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback_steps: Optional[int] = 1, - **kwargs, ): if isinstance(prompt, str): batch_size = 1 @@ -261,8 +280,10 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample - latents = np.array(latents) + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() # call the callback, if provided if callback is not None and i % callback_steps == 0: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 483b5fd2d3..1a878535c1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -25,7 +25,7 @@ from ...configuration_utils import FrozenDict from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import deprecate, logging +from ...utils import PIL_INTERPOLATION, deprecate, logging from . import StableDiffusionPipelineOutput @@ -35,7 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name def preprocess(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) return 2.0 * image - 1.0 @@ -77,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): safety_checker: OnnxRuntimeModel feature_extractor: CLIPFeatureExtractor + _optional_components = ["safety_checker", "feature_extractor"] + def __init__( self, vae_encoder: OnnxRuntimeModel, @@ -87,6 +89,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -117,7 +120,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: + if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" @@ -127,6 +130,12 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + self.register_modules( vae_encoder=vae_encoder, vae_decoder=vae_decoder, @@ -137,6 +146,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): @@ -231,7 +241,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback_steps: Optional[int] = 1, - **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -401,8 +410,10 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample - latents = latents.numpy() + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() # call the callback, if provided if callback is not None and i % callback_steps == 0: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 8e5c201319..930d61de99 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -25,7 +25,7 @@ from ...configuration_utils import FrozenDict from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import deprecate, logging +from ...utils import PIL_INTERPOLATION, deprecate, logging from . import StableDiffusionPipelineOutput @@ -44,7 +44,7 @@ def prepare_mask_and_masked_image(image, mask, latents_shape): image_mask = np.array(mask.convert("L").resize((latents_shape[1] * 8, latents_shape[0] * 8))) masked_image = image * (image_mask < 127.5) - mask = mask.resize((latents_shape[1], latents_shape[0]), PIL.Image.NEAREST) + mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"]) mask = np.array(mask.convert("L")) mask = mask.astype(np.float32) / 255.0 mask = mask[None, None] @@ -90,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): safety_checker: OnnxRuntimeModel feature_extractor: CLIPFeatureExtractor + _optional_components = ["safety_checker", "feature_extractor"] + def __init__( self, vae_encoder: OnnxRuntimeModel, @@ -100,6 +102,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: OnnxRuntimeModel, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") @@ -131,7 +134,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: + if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" @@ -141,6 +144,12 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + self.register_modules( vae_encoder=vae_encoder, vae_decoder=vae_decoder, @@ -151,6 +160,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): @@ -236,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): prompt: Union[str, List[str]], image: PIL.Image.Image, mask_image: PIL.Image.Image, - height: int = 512, - width: int = 512, + height: Optional[int] = 512, + width: Optional[int] = 512, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -249,7 +259,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback_steps: Optional[int] = 1, - **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -312,6 +321,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): @@ -408,9 +418,9 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents # concat latents, mask, masked_image_latnets in the channel dimension - latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) latent_model_input = latent_model_input.cpu().numpy() + latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) # predict the noise residual timestep = np.array([t], dtype=timestep_dtype) @@ -424,8 +434,10 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample - latents = latents.numpy() + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() # call the callback, if provided if callback is not None and i % callback_steps == 0: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py new file mode 100644 index 0000000000..2f990651a4 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,456 @@ +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from transformers import CLIPFeatureExtractor, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...onnx_utils import OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import deprecate, logging +from . import StableDiffusionPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask, scale_factor=8): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + return mask + + +class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. This is a *legacy feature* for Onnx pipelines to + provide compatibility with StableDiffusionInpaintPipelineLegacy and may be removed in the future. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + vae_encoder: OnnxRuntimeModel + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPFeatureExtractor + + def __init__( + self, + vae_encoder: OnnxRuntimeModel, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + uncond_embeddings = np.repeat(uncond_embeddings, num_images_per_prompt, axis=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def __call__( + self, + prompt: Union[str, List[str]], + init_image: Union[np.ndarray, PIL.Image.Image], + mask_image: Union[np.ndarray, PIL.Image.Image], + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[np.random.RandomState] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + init_image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. This is the image whose masked region will be inpainted. + mask_image (`nd.ndarray` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.uu + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. + `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The + number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added + noise will be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (?) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`np.random.RandomState`, *optional*): + A np.random.RandomState to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + if isinstance(init_image, PIL.Image.Image): + init_image = preprocess(init_image) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + text_embeddings = self._encode_prompt( + prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + latents_dtype = text_embeddings.dtype + init_image = init_image.astype(latents_dtype) + + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=init_image)[0] + init_latents = 0.18215 * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + init_latents_orig = init_latents + + # preprocess mask + if not isinstance(mask_image, np.ndarray): + mask_image = preprocess_mask(mask_image, 8) + mask_image = mask_image.astype(latents_dtype) + mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to ? in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ).prev_sample + + latents = latents.numpy() + + init_latents_proper = self.scheduler.add_noise( + torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.from_numpy(np.array([t])) + ) + + init_latents_proper = init_latents_proper.numpy() + + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="np" + ).pixel_values.astype(image.dtype) + # There will throw an error if use safety_checker batchsize>1 + images, has_nsfw_concept = [], [] + for i in range(image.shape[0]): + image_i, has_nsfw_concept_i = self.safety_checker( + clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1] + ) + images.append(image_i) + has_nsfw_concept.append(has_nsfw_concept_i[0]) + image = np.concatenate(images) + else: + has_nsfw_concept = None + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index e635347293..afaef6f481 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Union import torch from diffusers.utils import is_accelerate_available +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict @@ -66,6 +67,7 @@ class StableDiffusionPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, @@ -83,6 +85,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -113,8 +116,8 @@ class StableDiffusionPipeline(DiffusionPipeline): new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: - logger.warn( + if safety_checker is None and requires_safety_checker: + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -123,6 +126,33 @@ class StableDiffusionPipeline(DiffusionPipeline): " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -132,6 +162,8 @@ class StableDiffusionPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_xformers_memory_efficient_attention(self): r""" @@ -165,9 +197,14 @@ class StableDiffusionPipeline(DiffusionPipeline): `attention_head_dim` must be a multiple of `slice_size`. """ if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + self.unet.set_attention_slice(slice_size) def disable_attention_slicing(self): @@ -178,7 +215,23 @@ class StableDiffusionPipeline(DiffusionPipeline): # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self): + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -189,12 +242,17 @@ class StableDiffusionPipeline(DiffusionPipeline): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + @property def _execution_device(self): r""" @@ -248,7 +306,17 @@ class StableDiffusionPipeline(DiffusionPipeline): "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_embeddings = self.text_encoder(text_input_ids.to(device))[0] + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape @@ -284,7 +352,17 @@ class StableDiffusionPipeline(DiffusionPipeline): truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] @@ -349,7 +427,7 @@ class StableDiffusionPipeline(DiffusionPipeline): ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // 8, width // 8) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: if device.type == "mps": # randn does not work reproducibly on mps @@ -369,8 +447,8 @@ class StableDiffusionPipeline(DiffusionPipeline): def __call__( self, prompt: Union[str, List[str]], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -382,7 +460,6 @@ class StableDiffusionPipeline(DiffusionPipeline): return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, - **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -390,9 +467,9 @@ class StableDiffusionPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -438,6 +515,9 @@ class StableDiffusionPipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs(prompt, height, width, callback_steps) @@ -476,25 +556,29 @@ class StableDiffusionPipeline(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 8. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py new file mode 100644 index 0000000000..e64a572a87 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -0,0 +1,481 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch + +import PIL +from diffusers.utils import is_accelerate_available +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import deprecate, logging +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionImageVariationPipeline(DiffusionPipeline): + r""" + Pipeline to generate variations from an input image using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + image_encoder: CLIPVisionModelWithProjection, + unet: UNet2DConditionModel, + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae, self.safety_checker]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(images=image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + uncond_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + The image or images to guide the image generation. If you provide a tensor, it needs to comply with the + configuration of + [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json) + `CLIPFeatureExtractor` + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 9df800dc2d..a25acc0bd1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -20,6 +20,7 @@ import torch import PIL from diffusers.utils import is_accelerate_available +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict @@ -33,7 +34,7 @@ from ...schedulers import ( LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import deprecate, logging +from ...utils import PIL_INTERPOLATION, deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -44,7 +45,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name def preprocess(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) @@ -78,6 +79,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( @@ -96,6 +98,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -126,8 +129,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: - logger.warn( + if safety_checker is None and requires_safety_checker: + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -136,6 +139,33 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -145,6 +175,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): @@ -161,9 +193,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): `attention_head_dim` must be a multiple of `slice_size`. """ if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + self.unet.set_attention_slice(slice_size) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing @@ -176,7 +213,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): self.enable_attention_slicing(None) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self): + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -187,12 +224,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): @@ -268,7 +310,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_embeddings = self.text_encoder(text_input_ids.to(device))[0] + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape @@ -304,7 +356,17 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] @@ -380,7 +442,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:] - return timesteps + return timesteps, num_inference_steps - t_start def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): init_image = init_image.to(device=device, dtype=dtype) @@ -431,7 +493,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, - **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -510,7 +571,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -522,25 +583,29 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 9. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index a122723eee..6cb2766bc2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -20,6 +20,7 @@ import torch import PIL from diffusers.utils import is_accelerate_available +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict @@ -35,16 +36,88 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name def prepare_mask_and_masked_image(image, mask): - image = np.array(image.convert("RGB")) - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. - mask = np.array(mask.convert("L")) - mask = mask.astype(np.float32) / 255.0 - mask = mask[None, None] - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = torch.from_numpy(mask) + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + if isinstance(image, PIL.Image.Image): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + if isinstance(mask, PIL.Image.Image): + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) masked_image = image * (mask < 0.5) @@ -78,6 +151,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, @@ -88,6 +162,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -119,8 +194,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): new_config["skip_prk_steps"] = True scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: - logger.warn( + if safety_checker is None and requires_safety_checker: + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -129,6 +204,33 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -138,6 +240,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): @@ -154,9 +258,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): `attention_head_dim` must be a multiple of `slice_size`. """ if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + self.unet.set_attention_slice(slice_size) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing @@ -169,7 +278,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): self.enable_attention_slicing(None) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self): + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -180,12 +289,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention def enable_xformers_memory_efficient_attention(self): r""" @@ -261,7 +375,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_embeddings = self.text_encoder(text_input_ids.to(device))[0] + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape @@ -297,7 +421,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] @@ -367,7 +501,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // 8, width // 8) + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if latents is None: if device.type == "mps": # randn does not work reproducibly on mps @@ -389,7 +523,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision - mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8)) + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) @@ -417,8 +553,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): prompt: Union[str, List[str]], image: Union[torch.FloatTensor, PIL.Image.Image], mask_image: Union[torch.FloatTensor, PIL.Image.Image], - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -430,7 +566,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, - **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -446,9 +581,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. - height (`int`, *optional*, defaults to 512): + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -494,6 +629,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs self.check_inputs(prompt, height, width, callback_steps) @@ -517,7 +655,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps_tensor = self.scheduler.timesteps + timesteps = self.scheduler.timesteps # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels @@ -561,29 +699,32 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 10. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps_tensor)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 11. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 86d879eaa8..2440b6d5ad 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -20,6 +20,7 @@ import torch import PIL from diffusers.utils import is_accelerate_available +from packaging import version from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict @@ -33,7 +34,7 @@ from ...schedulers import ( LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import deprecate, logging +from ...utils import PIL_INTERPOLATION, deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -44,18 +45,18 @@ logger = logging.get_logger(__name__) def preprocess_image(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) return 2.0 * image - 1.0 -def preprocess_mask(mask): +def preprocess_mask(mask, scale_factor=8): mask = mask.convert("L") w, h = mask.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST) + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) mask = np.array(mask).astype(np.float32) / 255.0 mask = np.tile(mask, (4, 1, 1)) mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? @@ -91,6 +92,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ + _optional_components = ["safety_checker", "feature_extractor"] # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( @@ -109,6 +111,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, ): super().__init__() @@ -139,8 +142,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None: - logger.warn( + if safety_checker is None and requires_safety_checker: + logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" @@ -149,6 +152,33 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -158,6 +188,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): @@ -174,9 +206,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): `attention_head_dim` must be a multiple of `slice_size`. """ if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + self.unet.set_attention_slice(slice_size) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing @@ -189,7 +226,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): self.enable_attention_slicing(None) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload - def enable_sequential_cpu_offload(self): + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a @@ -200,12 +237,17 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): else: raise ImportError("Please install accelerate via `pip install accelerate`") - device = torch.device("cuda") + device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) + if self.safety_checker is not None: + # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate + # fix by only offloading self.safety_checker for now + cpu_offload(self.safety_checker.vision_model, device) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention def enable_xformers_memory_efficient_attention(self): r""" @@ -281,7 +323,17 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) - text_embeddings = self.text_encoder(text_input_ids.to(device))[0] + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape @@ -317,7 +369,17 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] @@ -395,7 +457,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:] - return timesteps + return timesteps, num_inference_steps - t_start def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator): init_image = init_image.to(device=self.device, dtype=dtype) @@ -430,7 +492,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: Optional[int] = 1, - **kwargs, ): r""" Function invoked when calling the pipeline for generation. @@ -512,11 +573,11 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): init_image = preprocess_image(init_image) if not isinstance(mask_image, torch.FloatTensor): - mask_image = preprocess_mask(mask_image) + mask_image = preprocess_mask(mask_image, self.vae_scale_factor) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -533,29 +594,33 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 9. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) # 10. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py new file mode 100644 index 0000000000..c9c238ce9a --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -0,0 +1,555 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers.utils import is_accelerate_available +from transformers import CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + # resize to multiple of 64 + width, height = image.size + width = width - width % 64 + height = height - height % 64 + image = image.resize((width, height)) + + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + return image + + +class StableDiffusionUpscalePipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image super-resolution using Stable Diffusion 2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + low_res_scheduler ([`SchedulerMixin`]): + A scheduler used to add initial noise to the low res conditioning image. It must be an instance of + [`DDPMScheduler`]. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + low_res_scheduler: DDPMScheduler, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + max_noise_level: int = 350, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + ) + self.register_to_config(max_noise_level=max_noise_level) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.unet.config.attention_head_dim) + + self.unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333 + def decode_latents(self, latents): + latents = 1 / 0.08333 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs(self, prompt, image, noise_level, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor + if isinstance(image, list) or isinstance(image, torch.Tensor): + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + # check noise level + if noise_level > self.config.max_noise_level: + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height, width) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs + self.check_inputs(prompt, image, noise_level, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Preprocess image + image = [image] if isinstance(image, PIL.Image.Image) else image + if isinstance(image, list): + image = [preprocess(img) for img in image] + image = torch.cat(image, dim=0) + image = image.to(dtype=text_embeddings.dtype, device=device) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Add noise to image + noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) + if device.type == "mps": + # randn does not work reproducibly on mps + noise = torch.randn(image.shape, generator=generator, device="cpu", dtype=text_embeddings.dtype).to(device) + else: + noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype) + image = self.low_res_scheduler.add_noise(image, noise, noise_level) + image = torch.cat([image] * 2) if do_classifier_free_guidance else image + noise_level = torch.cat([noise_level] * 2) if do_classifier_free_guidance else noise_level + + # 6. Prepare latent variables + height, width = image.shape[2:] + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = torch.cat([latent_model_input, image], dim=1) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 10. Post-processing + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + image = self.decode_latents(latents.float()) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/__init__.py b/src/diffusers/pipelines/stable_diffusion_safe/__init__.py new file mode 100644 index 0000000000..59ff61fa3b --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_safe/__init__.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Union + +import numpy as np + +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_torch_available, is_transformers_available + + +@dataclass +class SafetyConfig(object): + WEAK = { + "sld_warmup_steps": 15, + "sld_guidance_scale": 20, + "sld_threshold": 0.0, + "sld_momentum_scale": 0.0, + "sld_mom_beta": 0.0, + } + MEDIUM = { + "sld_warmup_steps": 10, + "sld_guidance_scale": 1000, + "sld_threshold": 0.01, + "sld_momentum_scale": 0.3, + "sld_mom_beta": 0.4, + } + STRONG = { + "sld_warmup_steps": 7, + "sld_guidance_scale": 2000, + "sld_threshold": 0.025, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + MAX = { + "sld_warmup_steps": 0, + "sld_guidance_scale": 5000, + "sld_threshold": 1.0, + "sld_momentum_scale": 0.5, + "sld_mom_beta": 0.7, + } + + +@dataclass +class StableDiffusionSafePipelineOutput(BaseOutput): + """ + Output class for Safe Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, or `None` if safety checking could not be performed. + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work" + (nsfw) content, or `None` if no safety check was performed or no images were flagged. + applied_safety_concept (`str`) + The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: Optional[List[bool]] + unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]] + applied_safety_concept: Optional[str] + + +if is_transformers_available() and is_torch_available(): + from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe + from .safety_checker import SafeStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py new file mode 100644 index 0000000000..7f08e40103 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -0,0 +1,764 @@ +import inspect +import warnings +from typing import Callable, List, Optional, Union + +import numpy as np +import torch + +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from ...utils import deprecate, is_accelerate_available, logging +from . import StableDiffusionSafePipelineOutput +from .safety_checker import SafeStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableDiffusionPipelineSafe(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Safe Latent Diffusion. + + The implementation is based on the [`StableDiffusionPipeline`] + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[ + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + ], + safety_checker: SafeStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + safety_concept: Optional[str] = ( + "an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity," + " bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child" + " abuse, brutality, cruelty" + ) + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self._safety_text_concept = safety_concept + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + @property + def safety_concept(self): + r""" + Getter method for the safety concept used with SLD + + Returns: + `str`: The text describing the safety concept + """ + return self._safety_text_concept + + @safety_concept.setter + def safety_concept(self, concept): + r""" + Setter method for the safety concept used with SLD + + Args: + concept (`str`): + The text of the new safety concept + """ + self._safety_text_concept = concept + + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.unet.set_use_memory_efficient_attention_xformers(True) + + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.unet.set_use_memory_efficient_attention_xformers(False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device("cuda") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + enable_safety_guidance, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = text_embeddings[0] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = uncond_embeddings[0] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # Encode the safety concept text + if enable_safety_guidance: + safety_concept_input = self.tokenizer( + [self._safety_text_concept], + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0] + + # duplicate safety embeddings for each generation per prompt, using mps friendly method + seq_len = safety_embeddings.shape[1] + safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1) + safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance + sld, we need to do three forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing three forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings]) + + else: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def run_safety_checker(self, image, device, dtype, enable_safety_guidance): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + flagged_images = None + if any(has_nsfw_concept): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead." + f" {'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} " + ) + flagged_images = np.zeros((2, *image.shape[1:])) + for idx, has_nsfw_concept in enumerate(has_nsfw_concept): + if has_nsfw_concept: + flagged_images[idx] = image[idx] + image[idx] = np.zeros(image[idx].shape) # black image + else: + has_nsfw_concept = None + flagged_images = None + return image, has_nsfw_concept, flagged_images + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def perform_safety_guidance( + self, + enable_safety_guidance, + safety_momentum, + noise_guidance, + noise_pred_out, + i, + sld_guidance_scale, + sld_warmup_steps, + sld_threshold, + sld_momentum_scale, + sld_mom_beta, + ): + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1] + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale + ) + + # Equation 4 + noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + return noise_guidance, safety_momentum + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + sld_guidance_scale: Optional[float] = 1000, + sld_warmup_steps: Optional[int] = 10, + sld_threshold: Optional[float] = 0.01, + sld_momentum_scale: Optional[float] = 0.3, + sld_mom_beta: Optional[float] = 0.4, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + sld_guidance_scale (`float`, *optional*, defaults to 1000): + Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). + `sld_guidance_scale` is defined as sS of Eq. 6. If set to be less than 1, safety guidance will be + disabled. + sld_warmup_steps (`int`, *optional*, defaults to 10): + Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than + `sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent + Diffusion](https://arxiv.org/abs/2211.05105). + sld_threshold (`float`, *optional*, defaults to 0.01): + Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold` + is defined as `lamda` of Eq. 5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105). + sld_momentum_scale (`float`, *optional*, defaults to 0.3): + Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0 + momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller + than `sld_warmup_steps`. `sld_momentum_scale` is defined as `sm` of Eq. 7 in [Safe Latent + Diffusion](https://arxiv.org/abs/2211.05105). + sld_mom_beta (`float`, *optional*, defaults to 0.4): + Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous + momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller + than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta m` of Eq. 8 in [Safe Latent + Diffusion](https://arxiv.org/abs/2211.05105). + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance + if not enable_safety_guidance: + warnings.warn("Safety checker disabled!") + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + safety_momentum = None + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * (3 if enable_safety_guidance else 2)) + if do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2)) + noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1] + + # default classifier free guidance + noise_guidance = noise_pred_text - noise_pred_uncond + + # Perform SLD guidance + if enable_safety_guidance: + if safety_momentum is None: + safety_momentum = torch.zeros_like(noise_guidance) + noise_pred_safety_concept = noise_pred_out[2] + + # Equation 6 + scale = torch.clamp( + torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0 + ) + + # Equation 6 + safety_concept_scale = torch.where( + (noise_pred_text - noise_pred_safety_concept) >= sld_threshold, + torch.zeros_like(scale), + scale, + ) + + # Equation 4 + noise_guidance_safety = torch.mul( + (noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale + ) + + # Equation 7 + noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum + + # Equation 8 + safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety + + if i >= sld_warmup_steps: # Warmup + # Equation 3 + noise_guidance = noise_guidance - noise_guidance_safety + + noise_pred = noise_pred_uncond + guidance_scale * noise_guidance + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept, flagged_images = self.run_safety_checker( + image, device, text_embeddings.dtype, enable_safety_guidance + ) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + if flagged_images is not None: + flagged_images = self.numpy_to_pil(flagged_images) + + if not return_dict: + return ( + image, + has_nsfw_concept, + self._safety_text_concept if enable_safety_guidance else None, + flagged_images, + ) + + return StableDiffusionSafePipelineOutput( + images=image, + nsfw_content_detected=has_nsfw_concept, + applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None, + unsafe_images=flagged_images, + ) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py b/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py new file mode 100644 index 0000000000..f9dbf51e86 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py @@ -0,0 +1,110 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class SafeStableDiffusionSafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + return images, has_nsfw_concepts diff --git a/src/diffusers/pipelines/versatile_diffusion/__init__.py b/src/diffusers/pipelines/versatile_diffusion/__init__.py new file mode 100644 index 0000000000..1d2caa7e23 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/__init__.py @@ -0,0 +1,16 @@ +from ...utils import is_torch_available, is_transformers_available, is_transformers_version + + +if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"): + from .modeling_text_unet import UNetFlatConditionModel + from .pipeline_versatile_diffusion import VersatileDiffusionPipeline + from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline + from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline + from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline +else: + from ...utils.dummy_torch_and_transformers_objects import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py new file mode 100644 index 0000000000..37a79b5c1b --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -0,0 +1,1135 @@ +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...modeling_utils import ModelMixin +from ...models.attention import DualTransformer2DModel, Transformer2DModel +from ...models.embeddings import TimestepEmbedding, Timesteps +from ...models.unet_2d_condition import UNet2DConditionOutput +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlockFlat": + return DownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + ) + elif down_block_type == "CrossAttnDownBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockFlat") + return CrossAttnDownBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + raise ValueError(f"{down_block_type} is not supported.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlockFlat": + return UpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + ) + elif up_block_type == "CrossAttnUpBlockFlat": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockFlat") + return CrossAttnUpBlockFlat( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + raise ValueError(f"{up_block_type} is not supported.") + + +# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat +class UNetFlatConditionModel(ModelMixin, ConfigMixin): + r""" + UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a + timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "CrossAttnDownBlockFlat", + "DownBlockFlat", + ), + up_block_types: Tuple[str] = ( + "UpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + "CrossAttnUpBlockFlat", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + num_class_embeds: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = LinearMultiDim(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlockFlatCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift="default", + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + head_dims = self.config.attention_head_dim + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + self.mid_block.set_attention_slice(slice_size) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.config.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +class LinearMultiDim(nn.Linear): + def __init__(self, in_features, out_features=None, second_dim=4, *args, **kwargs): + in_features = [in_features, second_dim, 1] if isinstance(in_features, int) else list(in_features) + if out_features is None: + out_features = in_features + out_features = [out_features, second_dim, 1] if isinstance(out_features, int) else list(out_features) + self.in_features_multidim = in_features + self.out_features_multidim = out_features + super().__init__(np.array(in_features).prod(), np.array(out_features).prod()) + + def forward(self, input_tensor, *args, **kwargs): + shape = input_tensor.shape + n_dim = len(self.in_features_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_features) + output_tensor = super().forward(input_tensor) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_features_multidim) + return output_tensor + + +class ResnetBlockFlat(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + time_embedding_norm="default", + use_in_shortcut=None, + second_dim=4, + **kwargs, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + + in_channels = [in_channels, second_dim, 1] if isinstance(in_channels, int) else list(in_channels) + self.in_channels_prod = np.array(in_channels).prod() + self.channels_multidim = in_channels + + if out_channels is not None: + out_channels = [out_channels, second_dim, 1] if isinstance(out_channels, int) else list(out_channels) + out_channels_prod = np.array(out_channels).prod() + self.out_channels_multidim = out_channels + else: + out_channels_prod = self.in_channels_prod + self.out_channels_multidim = self.channels_multidim + self.time_embedding_norm = time_embedding_norm + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=self.in_channels_prod, eps=eps, affine=True) + self.conv1 = torch.nn.Conv2d(self.in_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + if temb_channels is not None: + self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels_prod) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels_prod, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels_prod, out_channels_prod, kernel_size=1, padding=0) + + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = ( + self.in_channels_prod != out_channels_prod if use_in_shortcut is None else use_in_shortcut + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + self.in_channels_prod, out_channels_prod, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor, temb): + shape = input_tensor.shape + n_dim = len(self.channels_multidim) + input_tensor = input_tensor.reshape(*shape[0:-n_dim], self.in_channels_prod, 1, 1) + input_tensor = input_tensor.view(-1, self.in_channels_prod, 1, 1) + + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + output_tensor = output_tensor.view(*shape[0:-n_dim], -1) + output_tensor = output_tensor.view(*shape[0:-n_dim], *self.out_channels_multidim) + + return output_tensor + + +# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim +class DownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim +class CrossAttnDownBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + LinearMultiDim( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class UpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim +class CrossAttnUpBlockFlat(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + attention_type="default", + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlockFlat( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([LinearMultiDim(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def set_attention_slice(self, slice_size): + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat +class UNetMidBlockFlatCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + attention_type="default", + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + ): + super().__init__() + + self.attention_type = attention_type + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlockFlat( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def set_attention_slice(self, slice_size): + head_dims = self.attn_num_head_channels + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + + for attn in self.attentions: + attn._set_attention_slice(slice_size) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py new file mode 100644 index 0000000000..7be7f4d3ae --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -0,0 +1,463 @@ +import inspect +from typing import Callable, List, Optional, Union + +import torch + +import PIL.Image +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import logging +from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline +from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline +from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionMegaSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPFeatureExtractor + text_encoder: CLIPTextModel + image_encoder: CLIPVisionModel + image_unet: UNet2DConditionModel + text_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPFeatureExtractor, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModel, + image_unet: UNet2DConditionModel, + text_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + self.image_unet.set_attention_slice(slice_size) + self.text_unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + @torch.no_grad() + def image_variation( + self, + image: Union[torch.FloatTensor, PIL.Image.Image], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.image_variation(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + return VersatileDiffusionImageVariationPipeline(**components)( + image=image, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + + @torch.no_grad() + def text_to_image( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionTextToImagePipeline(**components) + output = temp_pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + # swap the attention blocks back to the original state + temp_pipeline._swap_unet_attention_blocks() + + return output + + @torch.no_grad() + def dual_guided( + self, + prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], + image: Union[str, List[str]], + text_to_image_strength: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe.dual_guided( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images. + """ + + expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys() + components = {name: component for name, component in self.components.items() if name in expected_components} + temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components) + output = temp_pipeline( + prompt=prompt, + image=image, + text_to_image_strength=text_to_image_strength, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + ) + temp_pipeline._revert_dual_attention() + + return output diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py new file mode 100644 index 0000000000..fa1754a4f0 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -0,0 +1,641 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint + +import PIL +from transformers import ( + CLIPFeatureExtractor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import DualTransformer2DModel, Transformer2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_accelerate_available, logging +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPFeatureExtractor + text_encoder: CLIPTextModelWithProjection + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + image_feature_extractor: CLIPFeatureExtractor, + text_encoder: CLIPTextModelWithProjection, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + image_feature_extractor=image_feature_extractor, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + if self.text_unet is not None and ( + "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention + ): + # if loading from a universal checkpoint rather than a saved dual-guided pipeline + self._convert_to_dual_attention() + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + def _convert_to_dual_attention(self): + """ + Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks + from both `image_unet` and `text_unet` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + + image_transformer = self.image_unet.get_submodule(parent_name)[index] + text_transformer = self.text_unet.get_submodule(parent_name)[index] + + config = image_transformer.config + dual_transformer = DualTransformer2DModel( + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + num_layers=config.num_layers, + dropout=config.dropout, + norm_num_groups=config.norm_num_groups, + cross_attention_dim=config.cross_attention_dim, + attention_bias=config.attention_bias, + sample_size=config.sample_size, + num_vector_embeds=config.num_vector_embeds, + activation_fn=config.activation_fn, + num_embeds_ada_norm=config.num_embeds_ada_norm, + ) + dual_transformer.transformers[0] = image_transformer + dual_transformer.transformers[1] = text_transformer + + self.image_unet.get_submodule(parent_name)[index] = dual_transformer + self.image_unet.register_to_config(dual_cross_attention=True) + + def _revert_dual_attention(self): + """ + Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call + this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline` + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index] = module.transformers[0] + + self.image_unet.register_to_config(dual_cross_attention=False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.image_unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.image_unet.config.attention_head_dim) + + self.image_unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + return self.device + for module in self.image_unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = normalize_embeddings(text_embeddings) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + uncond_embeddings = self.image_encoder(pixel_values) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, prompt, image, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}") + if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list): + raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")): + for name, module in self.image_unet.named_modules(): + if isinstance(module, DualTransformer2DModel): + module.mix_ratio = mix_ratio + + for i, type in enumerate(condition_types): + if type == "text": + module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings + module.transformer_index_for_condition[i] = 1 # use the second (text) transformer + else: + module.condition_lengths[i] = 257 + module.transformer_index_for_condition[i] = 0 # use the first (image) transformer + + @torch.no_grad() + def __call__( + self, + prompt: Union[PIL.Image.Image, List[PIL.Image.Image]], + image: Union[str, List[str]], + text_to_image_strength: float = 0.5, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionDualGuidedPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + >>> text = "a red car in the sun" + + >>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> text_to_image_strength = 0.75 + + >>> image = pipe( + ... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator + ... ).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When + returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, image, height, width, callback_steps) + + # 2. Define call parameters + prompt = [prompt] if not isinstance(prompt, list) else prompt + image = [image] if not isinstance(image, list) else image + batch_size = len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + text_embeddings = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance) + image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance) + dual_prompt_embeddings = torch.cat([text_embeddings, image_embeddings], dim=1) + prompt_types = ("text", "image") + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + dual_prompt_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Combine the attention blocks of the image and text UNets + self.set_transformer_params(text_to_image_strength, prompt_types) + + # 8. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py new file mode 100644 index 0000000000..3e51ce6371 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -0,0 +1,471 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +import torch.utils.checkpoint + +import PIL +from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_accelerate_available, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + image_feature_extractor: CLIPFeatureExtractor + image_encoder: CLIPVisionModelWithProjection + image_unet: UNet2DConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + def __init__( + self, + image_feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection, + image_unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules( + image_feature_extractor=image_feature_extractor, + image_encoder=image_encoder, + image_unet=image_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.image_unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.image_unet.config.attention_head_dim) + + self.image_unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + return self.device + for module in self.image_unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state) + embeds = self.image_encoder.visual_projection(embeds) + embeds_pooled = embeds[:, 0:1] + embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + image_input = self.image_feature_extractor(images=prompt, return_tensors="pt") + pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype) + image_embeddings = self.image_encoder(pixel_values) + image_embeddings = normalize_embeddings(image_embeddings) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_images: List[str] + if negative_prompt is None: + uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, PIL.Image.Image): + uncond_images = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_images = negative_prompt + + uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt") + pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype) + uncond_embeddings = self.image_encoder(pixel_values) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and conditional embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([uncond_embeddings, image_embeddings]) + + return image_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs(self, image, height, width, callback_steps): + if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor): + raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`): + The image prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionImageVariationPipeline + >>> import torch + >>> import requests + >>> from io import BytesIO + >>> from PIL import Image + + >>> # let's download an initial image + >>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg" + + >>> response = requests.get(url) + >>> image = Image.open(BytesIO(response.content)).convert("RGB") + + >>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe(image, generator=generator).images[0] + >>> image.save("./car_variation.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + image_embeddings = self._encode_prompt( + image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py new file mode 100644 index 0000000000..e77f5a2f22 --- /dev/null +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -0,0 +1,525 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch +import torch.utils.checkpoint + +from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention import Transformer2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_accelerate_available, logging +from .modeling_text_unet import UNetFlatConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + vqvae ([`VQModel`]): + Vector-quantized (VQ) Model to encode and decode images to and from latent representations. + bert ([`LDMBertModel`]): + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. + tokenizer (`transformers.BertTokenizer`): + Tokenizer of class + [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + tokenizer: CLIPTokenizer + image_feature_extractor: CLIPFeatureExtractor + text_encoder: CLIPTextModelWithProjection + image_unet: UNet2DConditionModel + text_unet: UNetFlatConditionModel + vae: AutoencoderKL + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + _optional_components = ["text_unet"] + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + image_unet: UNet2DConditionModel, + text_unet: UNetFlatConditionModel, + vae: AutoencoderKL, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + ): + super().__init__() + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + image_unet=image_unet, + text_unet=text_unet, + vae=vae, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + if self.text_unet is not None: + self._swap_unet_attention_blocks() + + def _swap_unet_attention_blocks(self): + """ + Swap the `Transformer2DModel` blocks between the image and text UNets + """ + for name, module in self.image_unet.named_modules(): + if isinstance(module, Transformer2DModel): + parent_name, index = name.rsplit(".", 1) + index = int(index) + self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = ( + self.text_unet.get_submodule(parent_name)[index], + self.image_unet.get_submodule(parent_name)[index], + ) + + def remove_unused_weights(self): + self.register_modules(text_unet=None) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet + def enable_xformers_memory_efficient_attention(self): + r""" + Enable memory efficient attention as implemented in xformers. + + When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference + time. Speed up at training time is not guaranteed. + + Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention + is used. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet + def disable_xformers_memory_efficient_attention(self): + r""" + Disable memory efficient attention as implemented in xformers. + """ + self.image_unet.set_use_memory_efficient_attention_xformers(False) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + if isinstance(self.image_unet.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.image_unet.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.image_unet.config.attention_head_dim) + + self.image_unet.set_attention_slice(slice_size) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"): + return self.device + for module in self.image_unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + """ + + def normalize_embeddings(encoder_output): + embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) + embeds_pooled = encoder_output.text_embeds + embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) + return embeds + + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids + + if not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + text_embeddings = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + text_embeddings = normalize_embeddings(text_embeddings) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = text_embeddings.shape + text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) + text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + uncond_embeddings = normalize_embeddings(uncond_embeddings) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs(self, prompt, height, width, callback_steps): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if latents is None: + if device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) + else: + latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + + ```py + >>> from diffusers import VersatileDiffusionTextToImagePipeline + >>> import torch + + >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained( + ... "shi-labs/versatile-diffusion", torch_dtype=torch.float16 + ... ) + >>> pipe.remove_unused_weights() + >>> pipe = pipe.to("cuda") + + >>> generator = torch.Generator(device="cuda").manual_seed(0) + >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0] + >>> image.save("./astronaut.png") + ``` + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.image_unet.config.sample_size * self.vae_scale_factor + width = width or self.image_unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.image_unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/vq_diffusion/__init__.py b/src/diffusers/pipelines/vq_diffusion/__init__.py index edf6f570f5..8c9f14f000 100644 --- a/src/diffusers/pipelines/vq_diffusion/__init__.py +++ b/src/diffusers/pipelines/vq_diffusion/__init__.py @@ -1 +1,5 @@ -from .pipeline_vq_diffusion import VQDiffusionPipeline +from ...utils import is_torch_available, is_transformers_available + + +if is_transformers_available() and is_torch_available(): + from .pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings, VQDiffusionPipeline diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py index 6e5325ba7e..333599d7ec 100644 --- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py +++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py @@ -20,6 +20,8 @@ from diffusers import Transformer2DModel, VQModel from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler from transformers import CLIPTextModel, CLIPTokenizer +from ...configuration_utils import ConfigMixin, register_to_config +from ...modeling_utils import ModelMixin from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...utils import logging @@ -27,6 +29,28 @@ from ...utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): + """ + Utility class for storing learned text embeddings for classifier free sampling + """ + + @register_to_config + def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None): + super().__init__() + + self.learnable = learnable + + if self.learnable: + assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" + assert length is not None, "learnable=True requires `length` to be set" + + embeddings = torch.zeros(length, hidden_size) + else: + embeddings = None + + self.embeddings = torch.nn.Parameter(embeddings) + + class VQDiffusionPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using VQ Diffusion @@ -55,6 +79,7 @@ class VQDiffusionPipeline(DiffusionPipeline): text_encoder: CLIPTextModel tokenizer: CLIPTokenizer transformer: Transformer2DModel + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings scheduler: VQDiffusionScheduler def __init__( @@ -64,6 +89,7 @@ class VQDiffusionPipeline(DiffusionPipeline): tokenizer: CLIPTokenizer, transformer: Transformer2DModel, scheduler: VQDiffusionScheduler, + learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, ): super().__init__() @@ -73,13 +99,78 @@ class VQDiffusionPipeline(DiffusionPipeline): text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, ) + def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + + # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. + # While CLIP does normalize the pooled output of the text transformer when combining + # the image and text embeddings, CLIP does not directly normalize the last hidden state. + # + # CLIP normalizing the pooled output. + # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 + text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) + + # duplicate text embeddings for each generation per prompt + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + if self.learned_classifier_free_sampling_embeddings.learnable: + uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings + uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1) + else: + uncond_tokens = [""] * batch_size + + max_length = text_input_ids.shape[-1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + # See comment for normalizing text embeddings + uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + return text_embeddings + @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], num_inference_steps: int = 100, + guidance_scale: float = 5.0, truncation_rate: float = 1.0, num_images_per_prompt: int = 1, generator: Optional[torch.Generator] = None, @@ -98,6 +189,12 @@ class VQDiffusionPipeline(DiffusionPipeline): num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)): Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above @@ -137,6 +234,10 @@ class VQDiffusionPipeline(DiffusionPipeline): batch_size = batch_size * num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + + text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance) + if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): @@ -145,35 +246,6 @@ class VQDiffusionPipeline(DiffusionPipeline): f" {type(callback_steps)}." ) - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - - if text_input_ids.shape[-1] > self.tokenizer.model_max_length: - removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] - - # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion. - # While CLIP does normalize the pooled output of the text transformer when combining - # the image and text embeddings, CLIP does not directly normalize the last hidden state. - # - # CLIP normalizing the pooled output. - # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053 - text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True) - - # duplicate text embeddings for each generation per prompt - text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) - # get the initial completely masked latents unless the user supplied it latents_shape = (batch_size, self.transformer.num_latent_pixels) @@ -198,9 +270,19 @@ class VQDiffusionPipeline(DiffusionPipeline): sample = latents for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the sample if we are doing classifier free guidance + latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample + # predict the un-noised image # model_output == `log_p_x_0` - model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample + model_output = self.transformer( + latent_model_input, encoder_hidden_states=text_embeddings, timestep=t + ).sample + + if do_classifier_free_guidance: + model_output_uncond, model_output_text = model_output.chunk(2) + model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond) + model_output -= torch.logsumexp(model_output, dim=1, keepdim=True) model_output = self.truncate(model_output, truncation_rate) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 6217bfcd69..d708963839 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -22,6 +22,7 @@ if is_torch_available(): from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler + from .scheduling_heun import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 75cef635d0..a2e571f998 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -23,7 +23,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2010.02502 @@ -106,17 +106,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. """ - _compatible_classes = [ - "PNDMScheduler", - "DDPMScheduler", - "LMSDiscreteScheduler", - "EulerDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + order = 1 @register_to_config def __init__( @@ -129,7 +127,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, + prediction_type: str = "epsilon", + **kwargs, ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DDIMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": @@ -265,7 +273,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) # 4. Clip "predicted x_0" if self.config.clip_sample: @@ -336,5 +356,25 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 590e3aac2e..f98d977004 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -23,7 +23,13 @@ import flax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from ..utils import deprecate +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: @@ -79,8 +85,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2010.02502 @@ -103,8 +109,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. + """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + @property def has_state(self): return True @@ -118,7 +131,17 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): beta_schedule: str = "linear", set_alpha_to_one: bool = True, steps_offset: int = 0, + prediction_type: str = "epsilon", + **kwargs, ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " FlaxDDIMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + if beta_schedule == "linear": self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) elif beta_schedule == "scaled_linear": @@ -252,7 +275,19 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) # 4. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index c3e373d2bd..d1dfa1a44b 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -22,7 +22,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2006.11239 @@ -99,19 +99,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - predict_epsilon (`bool`): - optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise. - + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. """ - _compatible_classes = [ - "DDIMScheduler", - "PNDMScheduler", - "LMSDiscreteScheduler", - "EulerDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + order = 1 @register_to_config def __init__( @@ -123,8 +118,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, - predict_epsilon: bool = True, + prediction_type: str = "epsilon", + **kwargs, ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DDPMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": @@ -248,13 +252,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): """ message = ( - "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" - " DDPMScheduler.from_config(, predict_epsilon=True)`." + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DDPMScheduler.from_pretrained(, prediction_type='epsilon')`." ) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) - if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: + if predict_epsilon is not None: new_config = dict(self.config) - new_config["predict_epsilon"] = predict_epsilon + new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample" self._internal_dict = FrozenDict(new_config) t = timestep @@ -272,10 +276,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.config.predict_epsilon: + if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - else: + elif self.config.prediction_type == "sample": pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " + " for the DDPMScheduler." + ) # 3. Clip "predicted x_0" if self.config.clip_sample: @@ -337,5 +346,25 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index f1b04a0417..97b38fd3a1 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -24,7 +24,12 @@ from jax import random from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config from ..utils import deprecate -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: @@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2006.11239 @@ -98,11 +103,14 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - predict_epsilon (`bool`): - optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise. - + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`. + `v-prediction` is not supported for this scheduler. """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + @property def has_state(self): return True @@ -117,8 +125,17 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): trained_betas: Optional[jnp.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, - predict_epsilon: bool = True, + prediction_type: str = "epsilon", + **kwargs, ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " FlaxDDPMScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + if trained_betas is not None: self.betas = jnp.asarray(trained_betas) elif beta_schedule == "linear": @@ -197,7 +214,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): timestep: int, sample: jnp.ndarray, key: random.KeyArray, - predict_epsilon: bool = True, return_dict: bool = True, **kwargs, ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: @@ -220,13 +236,13 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): """ message = ( - "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" - " DDPMScheduler.from_config(, predict_epsilon=True)`." + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " FlaxDDPMScheduler.from_pretrained(, prediction_type='epsilon')`." ) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) - if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: + if predict_epsilon is not None: new_config = dict(self.config) - new_config["predict_epsilon"] = predict_epsilon + new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample" self._internal_dict = FrozenDict(new_config) t = timestep @@ -244,10 +260,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.config.predict_epsilon: + if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - else: + elif self.config.prediction_type == "sample": pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` " + " for the FlaxDDPMScheduler." + ) # 3. Clip "predicted x_0" if self.config.clip_sample: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d166354809..e27b793b7b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -21,6 +21,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -86,10 +87,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - predict_epsilon (`bool`, default `True`): - we currently support both the noise prediction model and the data prediction model. If the model predicts - the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set - `predict_epsilon` to `False`. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, + or `v-prediction`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to @@ -116,14 +116,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "DDPMScheduler", - "PNDMScheduler", - "LMSDiscreteScheduler", - "EulerDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + order = 1 @register_to_config def __init__( @@ -134,14 +129,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, solver_order: int = 2, - predict_epsilon: bool = True, + prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, + **kwargs, ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " DPMSolverMultistepScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": @@ -209,7 +213,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. - DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an integral of the data prediction model. So we need to first convert the model output to the corresponding type to match the algorithm. @@ -227,13 +231,25 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type == "dpmsolver++": - if self.config.predict_epsilon: + if self.config.prediction_type == "epsilon": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t - else: + elif self.config.prediction_type == "sample": x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + if self.config.thresholding: # Dynamic thresholding in https://arxiv.org/abs/2205.11487 + orig_dtype = x0_pred.dtype + if orig_dtype not in [torch.float, torch.double]: + x0_pred = x0_pred.float() dynamic_max_val = torch.quantile( torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1 ) @@ -242,15 +258,25 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), )[(...,) + (None,) * (x0_pred.ndim - 1)] x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val + x0_pred = x0_pred.type(orig_dtype) return x0_pred # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type == "dpmsolver": - if self.config.predict_epsilon: + if self.config.prediction_type == "epsilon": return model_output - else: + elif self.config.prediction_type == "sample": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = (sample - alpha_t * model_output) / sigma_t return epsilon + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) def dpm_solver_first_order_update( self, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index c9a6d1cd5c..78b611ae27 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -23,7 +23,13 @@ import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from ..utils import deprecate +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: @@ -96,8 +102,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 @@ -113,10 +119,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): solver_order (`int`, default `2`): the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - predict_epsilon (`bool`, default `True`): - we currently support both the noise prediction model and the data prediction model. If the model predicts - the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set - `predict_epsilon` to `False`. + prediction_type (`str`, default `epsilon`): + indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, + or `v-prediction`. thresholding (`bool`, default `False`): whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to @@ -143,6 +148,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _deprecated_kwargs = ["predict_epsilon"] + @property def has_state(self): return True @@ -156,14 +164,23 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, solver_order: int = 2, - predict_epsilon: bool = True, + prediction_type: str = "epsilon", thresholding: bool = False, dynamic_thresholding_ratio: float = 0.995, sample_max_value: float = 1.0, algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, + **kwargs, ): + message = ( + "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler =" + " FlaxDPMSolverMultistepScheduler.from_pretrained(, prediction_type='epsilon')`." + ) + predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) + if predict_epsilon is not None: + self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample") + if trained_betas is not None: self.betas = jnp.asarray(trained_betas) elif beta_schedule == "linear": @@ -235,7 +252,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. - DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an integral of the data prediction model. So we need to first convert the model output to the corresponding type to match the algorithm. @@ -253,11 +270,20 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type == "dpmsolver++": - if self.config.predict_epsilon: + if self.config.prediction_type == "epsilon": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t - else: + elif self.config.prediction_type == "sample": x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." + ) + if self.config.thresholding: # Dynamic thresholding in https://arxiv.org/abs/2205.11487 dynamic_max_val = jnp.percentile( @@ -270,12 +296,21 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): return x0_pred # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type == "dpmsolver": - if self.config.predict_epsilon: + if self.config.prediction_type == "epsilon": return model_output - else: + elif self.config.prediction_type == "sample": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = (sample - alpha_t * model_output) / sigma_t return epsilon + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + " or `v_prediction` for the FlaxDPMSolverMultistepScheduler." + ) def dpm_solver_first_order_update( self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 621b5c17c0..301ad2cebe 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -19,7 +19,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging from .scheduling_utils import SchedulerMixin @@ -52,8 +52,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -67,14 +67,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "DDPMScheduler", - "LMSDiscreteScheduler", - "PNDMScheduler", - "EulerDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + order = 1 @register_to_config def __init__( @@ -196,7 +190,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ) if not self.is_scale_input_called: - logger.warn( + logger.warning( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 2f9e938474..10b0138abd 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -19,7 +19,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging from .scheduling_utils import SchedulerMixin @@ -53,8 +53,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -68,14 +68,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "DDPMScheduler", - "LMSDiscreteScheduler", - "PNDMScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + order = 1 @register_to_config def __init__( @@ -85,6 +79,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, + prediction_type: str = "epsilon", ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -205,7 +200,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ) if not self.is_scale_input_called: - logger.warn( + logger.warning( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) @@ -236,7 +231,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - pred_original_sample = sample - sigma_hat * model_output + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma_hat * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) # 2. Convert to an ODE derivative derivative = (sample - pred_original_sample) / sigma_hat diff --git a/src/diffusers/schedulers/scheduling_heun.py b/src/diffusers/schedulers/scheduling_heun.py new file mode 100644 index 0000000000..e6e5335e0d --- /dev/null +++ b/src/diffusers/schedulers/scheduling_heun.py @@ -0,0 +1,247 @@ +# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Args: + Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original + k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90 + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the + starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + """ + + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + ): + if trained_betas is not None: + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + + def index_for_timestep(self, timestep): + indices = (self.timesteps == timestep).nonzero() + if self.state_in_first_order: + pos = 0 if indices.shape[0] < 2 else 1 + else: + pos = 0 + return indices[pos].item() + + def scale_model_input( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + ) -> torch.FloatTensor: + """ + Args: + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep + Returns: + `torch.FloatTensor`: scaled input sample + """ + step_index = self.index_for_timestep(timestep) + + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + timesteps = torch.from_numpy(timesteps) + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2), timesteps[-1:]]) + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = timesteps.to(device, dtype=torch.float32) + else: + self.timesteps = timesteps.to(device=device) + + # empty dt and derivative + self.prev_derivative = None + self.dt = None + + @property + def state_in_first_order(self): + return self.dt is None + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Args: + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep + (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + step_index = self.index_for_timestep(timestep) + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + sigma_next = self.sigmas[step_index + 1] + else: + # 2nd order / Heun's method + sigma = self.sigmas[step_index - 1] + sigma_next = self.sigmas[step_index] + + # currently only gamma=0 is supported. This usually works best anyways. + # We can support gamma in the future but then need to scale the timestep before + # passing it to the model which requires a change in API + gamma = 0 + sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma_hat * model_output + + if self.state_in_first_order: + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + # 3. 1st order derivative + dt = sigma_next - sigma_hat + + # store for 2nd order step + self.prev_derivative = derivative + self.dt = dt + self.sample = sample + else: + # 2. 2nd order / Heun's method + derivative = (sample - pred_original_sample) / sigma_hat + derivative = (self.prev_derivative + derivative) / 2 + + # 3. Retrieve 1st order derivative + dt = self.dt + sample = self.sample + + # free dt and derivative + # Note, this puts the scheduler in "first order mode" + self.prev_derivative = None + self.dt = None + self.sample = None + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [self.index_for_timestep(t) for t in timesteps] + + sigma = self.sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py index fb413a2805..1bcebe65a3 100644 --- a/src/diffusers/schedulers/scheduling_ipndm.py +++ b/src/diffusers/schedulers/scheduling_ipndm.py @@ -28,8 +28,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2202.09778 @@ -37,6 +37,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps (`int`): number of diffusion steps used to train the model. """ + order = 1 + @register_to_config def __init__(self, num_train_timesteps: int = 1000): # set `betas`, `alphas`, `timesteps` diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 743f2e061c..41a73b3ac3 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -56,8 +56,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the @@ -77,6 +77,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): """ + order = 2 + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index 78ab007954..c4e612c3cc 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -67,8 +67,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 373c373ee0..68deae8943 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -21,7 +21,7 @@ import torch from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput from .scheduling_utils import SchedulerMixin @@ -52,8 +52,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -67,14 +67,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "DDPMScheduler", - "PNDMScheduler", - "EulerDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + order = 1 @register_to_config def __init__( @@ -250,19 +244,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: - self.timesteps = self.timesteps.to(original_samples.device) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - schedule_timesteps = self.timesteps step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = self.sigmas[step_indices].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 20982d38aa..21f25f72fa 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -20,7 +20,12 @@ import jax.numpy as jnp from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) @flax.struct.dataclass @@ -49,8 +54,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -63,6 +68,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + @property def has_state(self): return True diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index eec18af8d3..e2a076925c 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -21,6 +21,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -60,8 +61,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2202.09778 @@ -88,14 +89,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "DDPMScheduler", - "LMSDiscreteScheduler", - "EulerDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + order = 1 @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 357ecfe046..298e62de20 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -23,7 +23,12 @@ import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: @@ -87,8 +92,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2202.09778 @@ -114,6 +119,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): stable diffusion. """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + @property def has_state(self): return True diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index 1751f41c66..0b80181f43 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -77,8 +77,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf @@ -102,6 +102,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): """ + order = 1 + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index d31adbc3c6..89d3d4a585 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -50,8 +50,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -66,6 +66,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): correct_steps (`int`): number of correction steps performed on a produced sample. """ + order = 1 + @register_to_config def __init__( self, diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index d3eadede61..d1f762bc90 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -64,8 +64,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index a37a159a87..5e4fe40229 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -29,8 +29,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more information, see the original paper: https://arxiv.org/abs/2011.13456 @@ -38,6 +38,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): """ + order = 1 + @register_to_config def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): self.sigmas = None diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 29bf982f6a..90ab674e38 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib +import os from dataclasses import dataclass +from typing import Any, Dict, Optional, Union import torch @@ -38,6 +41,114 @@ class SchedulerOutput(BaseOutput): class SchedulerMixin: """ Mixin containing common functions for the schedulers. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that + `from_config` can be used from a class different than the one used to save the config (should be overridden + by parent class). """ config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Dict[str, Any] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing the schedluer configurations saved using + [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config, kwargs = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~SchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index e545cfe247..5dc28c25d9 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib +import os from dataclasses import dataclass -from typing import Tuple +from typing import Any, Dict, Optional, Tuple, Union import jax.numpy as jnp -from ..utils import BaseOutput +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput SCHEDULER_CONFIG_NAME = "scheduler_config.json" +_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS] @dataclass @@ -39,9 +42,126 @@ class FlaxSchedulerOutput(BaseOutput): class FlaxSchedulerMixin: """ Mixin containing common functions for the schedulers. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that + `from_config` can be used from a class different than the one used to save the config (should be overridden + by parent class). """ config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Dict[str, Any] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Scheduler class from a pre-defined JSON-file. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`], + e.g., `./my_model_directory/`. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config, kwargs = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + **kwargs, + ) + scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs) + + if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False): + state = scheduler.create_state() + + if return_unused_kwargs: + return scheduler, state, unused_kwargs + + return scheduler, state + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~FlaxSchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray: diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py index dbe320d998..89ba722a18 100644 --- a/src/diffusers/schedulers/scheduling_vq_diffusion.py +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -112,8 +112,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2111.14822 @@ -138,6 +138,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): The ending cumulative gamma value. """ + order = 1 + @register_to_config def __init__( self, diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index a00e1f4dcd..1c2e2c9abb 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -28,16 +28,19 @@ from .import_utils import ( is_inflect_available, is_modelcards_available, is_onnx_available, + is_safetensors_available, is_scipy_available, is_tf_available, is_torch_available, is_torch_version, is_transformers_available, + is_transformers_version, is_unidecode_available, requires_backends, ) from .logging import get_logger from .outputs import BaseOutput +from .pil_utils import PIL_INTERPOLATION if is_torch_available(): @@ -67,8 +70,20 @@ CONFIG_NAME = "config.json" WEIGHTS_NAME = "diffusion_pytorch_model.bin" FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" ONNX_WEIGHTS_NAME = "model.onnx" +SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors" ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb" HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" DIFFUSERS_CACHE = default_cache_path DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) + +_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [ + "DDIMScheduler", + "DDPMScheduler", + "PNDMScheduler", + "LMSDiscreteScheduler", + "EulerDiscreteScheduler", + "HeunDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", +] diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py index eac4303157..7c8bfc901b 100644 --- a/src/diffusers/utils/deprecation_utils.py +++ b/src/diffusers/utils/deprecation_utils.py @@ -32,7 +32,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn if warning is not None: warning = warning + " " if standard_warn else "" - warnings.warn(warning + message, DeprecationWarning) + warnings.warn(warning + message, FutureWarning) if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: call_frame = inspect.getouterframes(inspect.currentframe())[1] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index af2e0c7c61..9846927cb1 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -362,6 +362,21 @@ class EulerDiscreteScheduler(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class HeunDiscreteScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class IPNDMScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py index 221020030e..ae9412a956 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py @@ -34,6 +34,21 @@ class OnnxStableDiffusionInpaintPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers", "onnx"]) +class OnnxStableDiffusionInpaintPipelineLegacy(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "onnx"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "onnx"]) + + class OnnxStableDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "onnx"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 63e8a60f74..2d932d2405 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -4,6 +4,36 @@ from ..utils import DummyObject, requires_backends +class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AltDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CycleDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -34,6 +64,21 @@ class LDMTextToImagePipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -94,6 +139,96 @@ class StableDiffusionPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionPipelineSafe(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionUpscalePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class VQDiffusionPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 005cbb6170..531f9eab2f 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -42,6 +42,7 @@ ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() +USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper() STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} @@ -55,7 +56,7 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA except importlib_metadata.PackageNotFoundError: _torch_available = False else: - logger.info("Disabling PyTorch because USE_TF is set") + logger.info("Disabling PyTorch because USE_TORCH is set") _torch_available = False @@ -109,6 +110,17 @@ if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: else: _flax_available = False +if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES: + _safetensors_available = importlib.util.find_spec("safetensors") is not None + if _safetensors_available: + try: + _safetensors_version = importlib_metadata.version("safetensors") + logger.info(f"Safetensors version {_safetensors_version} available.") + except importlib_metadata.PackageNotFoundError: + _safetensors_available = False +else: + logger.info("Disabling Safetensors because USE_TF is set") + _safetensors_available = False _transformers_available = importlib.util.find_spec("transformers") is not None try: @@ -145,7 +157,13 @@ except importlib_metadata.PackageNotFoundError: _onnxruntime_version = "N/A" _onnx_available = importlib.util.find_spec("onnxruntime") is not None if _onnx_available: - candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino") + candidates = ( + "onnxruntime", + "onnxruntime-gpu", + "onnxruntime-directml", + "onnxruntime-openvino", + "ort_nightly_directml", + ) _onnxruntime_version = None # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu for pkg in candidates: @@ -190,6 +208,10 @@ def is_torch_available(): return _torch_available +def is_safetensors_available(): + return _safetensors_available + + def is_tf_available(): return _tf_available @@ -303,6 +325,17 @@ def requires_backends(obj, backends): if failed: raise ImportError("".join(failed)) + if name in [ + "VersatileDiffusionTextToImagePipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionDualGuidedPipeline", + "StableDiffusionImageVariationPipeline", + ] and is_transformers_version("<", "4.25.0.dev0"): + raise ImportError( + f"You need to install `transformers` from 'main' in order to use {name}: \n```\n pip install" + " git+https://github.com/huggingface/transformers \n```" + ) + class DummyObject(type): """ @@ -347,3 +380,17 @@ def is_torch_version(operation: str, version: str): A string version of PyTorch """ return compare_versions(parse(_torch_version), operation, version) + + +def is_transformers_version(operation: str, version: str): + """ + Args: + Compares the current Transformers version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + if not _transformers_available: + return False + return compare_versions(parse(_transformers_version), operation, version) diff --git a/src/diffusers/utils/pil_utils.py b/src/diffusers/utils/pil_utils.py new file mode 100644 index 0000000000..39d0a15a4e --- /dev/null +++ b/src/diffusers/utils/pil_utils.py @@ -0,0 +1,21 @@ +import PIL.Image +import PIL.ImageOps +from packaging import version + + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } diff --git a/tests/fixtures/custom_pipeline/what_ever.py b/tests/fixtures/custom_pipeline/what_ever.py new file mode 100644 index 0000000000..e7429d0a19 --- /dev/null +++ b/tests/fixtures/custom_pipeline/what_ever.py @@ -0,0 +1,101 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +from typing import Optional, Tuple, Union + +import torch + +from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class CustomLocalPipeline(DiffusionPipeline): + r""" + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + batch_size: int = 1, + generator: Optional[torch.Generator] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Args: + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + eta (`float`, *optional*, defaults to 0.0): + The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + # Sample gaussian noise to begin loop + image = torch.randn( + (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size), + generator=generator, + ) + image = image.to(self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + # 1. predict noise model_output + model_output = self.unet(image, t).sample + + # 2. predict previous mean of image x_t-1 and add variance depending on eta + # eta corresponds to η in paper and should be between [0, 1] + # do x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image).prev_sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,), "This is a local test" + + return ImagePipelineOutput(images=image), "This is a local test" diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 202936324a..7f61cbfb03 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -67,8 +67,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): super().test_from_pretrained_save_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") - def test_model_from_config(self): - super().test_model_from_config() + def test_model_from_pretrained(self): + super().test_model_from_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_output(self): @@ -187,8 +187,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): super().test_from_pretrained_save_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") - def test_model_from_config(self): - super().test_model_from_config() + def test_model_from_pretrained(self): + super().test_model_from_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_output(self): diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 81437311c6..59b9e02ff8 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -296,6 +296,44 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): for name, param in named_params.items(): self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) + def test_model_with_attention_head_dim_tuple(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_model_with_use_linear_projection(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["use_linear_projection"] = True + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel @@ -601,3 +639,29 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase): expected_output_slice = torch.tensor(expected_slice) assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], + [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], + [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], + [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], + # fmt: on + ] + ) + @require_torch_gpu + def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) + latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) diff --git a/tests/models/test_models_unet_2d_flax.py b/tests/models/test_models_unet_2d_flax.py new file mode 100644 index 0000000000..4b279d2f33 --- /dev/null +++ b/tests/models/test_models_unet_2d_flax.py @@ -0,0 +1,103 @@ +import gc +import unittest + +from diffusers import FlaxUNet2DConditionModel +from diffusers.utils import is_flax_available +from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow +from parameterized import parameterized + + +if is_flax_available(): + import jax + import jax.numpy as jnp + + +@slow +@require_flax +class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase): + def get_file_format(self, seed, shape): + return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" + + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + + def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) + return image + + def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + revision = "bf16" if fp16 else None + + model, params = FlaxUNet2DConditionModel.from_pretrained( + model_id, subfolder="unet", dtype=dtype, revision=revision + ) + return model, params + + def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) + return hidden_states + + @parameterized.expand( + [ + # fmt: off + [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], + [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], + [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], + [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]], + # fmt: on + ] + ) + def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice): + model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) + latents = self.get_latents(seed, fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + sample = model.apply( + {"params": params}, + latents, + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=encoder_hidden_states, + ).sample + + assert sample.shape == latents.shape + + output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) + expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) + + # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware + assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], + [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], + [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], + [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], + # fmt: on + ] + ) + def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice): + model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) + latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + + sample = model.apply( + {"params": params}, + latents, + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=encoder_hidden_states, + ).sample + + assert sample.shape == latents.shape + + output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) + expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) + + # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware + assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) diff --git a/tests/pipelines/altdiffusion/__init__.py b/tests/pipelines/altdiffusion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py new file mode 100644 index 0000000000..91fe764449 --- /dev/null +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -0,0 +1,347 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import AltDiffusionPipeline, AutoencoderKL, DDIMScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( + RobertaSeriesConfig, + RobertaSeriesModelWithTransformation, +) +from diffusers.utils import floats_tensor, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import XLMRobertaTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_cond_unet_inpaint(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = RobertaSeriesConfig( + hidden_size=32, + project_dim=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + vocab_size=5002, + ) + return RobertaSeriesModelWithTransformation(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_alt_diffusion_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta") + tokenizer.model_max_length = 77 + + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + alt_pipe = alt_pipe.to(device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A photo of an astronaut" + + generator = torch.Generator(device=device).manual_seed(0) + output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = alt_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_alt_diffusion_pndm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta") + tokenizer.model_max_length = 77 + + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + alt_pipe = alt_pipe.to(device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = alt_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_alt_diffusion_fp16(self): + """Test that stable diffusion works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta") + tokenizer.model_max_length = 77 + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + alt_pipe = alt_pipe.to(torch_device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 64, 64, 3) + + +@slow +@require_torch_gpu +class AltDiffusionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_alt_diffusion(self): + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", safety_checker=None) + alt_pipe = alt_pipe.to(torch_device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast("cuda"): + output = alt_pipe( + [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np" + ) + + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array( + [0.8720703, 0.87109375, 0.87402344, 0.87109375, 0.8779297, 0.8925781, 0.8823242, 0.8808594, 0.8613281] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_alt_diffusion_fast_ddim(self): + scheduler = DDIMScheduler.from_pretrained("BAAI/AltDiffusion", subfolder="scheduler") + + alt_pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion", scheduler=scheduler, safety_checker=None) + alt_pipe = alt_pipe.to(torch_device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + + with torch.autocast("cuda"): + output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy") + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array( + [0.9267578, 0.9301758, 0.9013672, 0.9345703, 0.92578125, 0.94433594, 0.9423828, 0.9423828, 0.9160156] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_alt_diffusion_text2img_pipeline_fp16(self): + torch.cuda.reset_peak_memory_stats() + model_id = "BAAI/AltDiffusion" + pipe = AltDiffusionPipeline.from_pretrained( + model_id, revision="fp16", torch_dtype=torch.float16, safety_checker=None + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output_chunked = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image = output.images + + # Make sure results are close enough + diff = np.abs(image_chunked.flatten() - image.flatten()) + # They ARE different since ops are not run always at the same precision + # however, they should be extremely close. + assert diff.mean() < 2e-2 diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py new file mode 100644 index 0000000000..0dab14b317 --- /dev/null +++ b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import AltDiffusionImg2ImgPipeline, AutoencoderKL, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( + RobertaSeriesConfig, + RobertaSeriesModelWithTransformation, +) +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import XLMRobertaTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class AltDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = RobertaSeriesConfig( + hidden_size=32, + project_dim=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=5006, + ) + return RobertaSeriesModelWithTransformation(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_img2img_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta") + tokenizer.model_max_length = 77 + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + alt_pipe = alt_pipe.to(device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = alt_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = alt_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array( + [0.41293705, 0.38656747, 0.40876025, 0.4782187, 0.4656803, 0.41394007, 0.4142093, 0.47150758, 0.4570448] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1.5e-3 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1.5e-3 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_img2img_fp16(self): + """Test that stable diffusion img2img works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = XLMRobertaTokenizer.from_pretrained("hf-internal-testing/tiny-xlm-roberta") + tokenizer.model_max_length = 77 + + init_image = self.dummy_image.to(torch_device) + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + alt_pipe = AltDiffusionImg2ImgPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + alt_pipe = alt_pipe.to(torch_device) + alt_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = alt_pipe( + [prompt], + generator=generator, + num_inference_steps=2, + output_type="np", + init_image=init_image, + ).images + + assert image.shape == (1, 32, 32, 3) + + +@slow +@require_torch_gpu +class AltDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_img2img_pipeline_default(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((768, 512)) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape_alt.npy" + ) + + model_id = "BAAI/AltDiffusion" + pipe = AltDiffusionImg2ImgPipeline.from_pretrained( + model_id, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A fantasy landscape, trending on artstation" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 768, 3) + # img2img is flaky across GPUs even in fp32, so using MAE here + assert np.abs(expected_image - image).max() < 1e-3 diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index 81c49912be..2d03383599 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -75,7 +75,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase): model_id = "google/ddpm-ema-bedroom-256" unet = UNet2DModel.from_pretrained(model_id) - scheduler = DDIMScheduler.from_config(model_id) + scheduler = DDIMScheduler.from_pretrained(model_id) ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) ddpm.to(torch_device) diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py index 14bc094697..6656fb738d 100644 --- a/tests/pipelines/ddpm/test_ddpm.py +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -68,7 +68,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - def test_inference_predict_epsilon(self): + def test_inference_deprecated_predict_epsilon(self): deprecate("remove this test", "0.10.0", "remove") unet = self.dummy_uncond_unet scheduler = DDPMScheduler(predict_epsilon=False) @@ -98,6 +98,35 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): tolerance = 1e-2 if torch_device != "mps" else 3e-2 assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance + def test_inference_predict_sample(self): + unet = self.dummy_uncond_unet + scheduler = DDPMScheduler(prediction_type="sample") + + ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = ddpm(num_inference_steps=1) + + if torch_device == "mps": + # device type MPS is not supported for torch.Generator() api. + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = generator.manual_seed(0) + image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")[0] + + image_slice = image[0, -3:, -3:, -1] + image_eps_slice = image_eps[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + tolerance = 1e-2 if torch_device != "mps" else 3e-2 + assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance + @slow @require_torch_gpu @@ -106,7 +135,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase): model_id = "google/ddpm-cifar10-32" unet = UNet2DModel.from_pretrained(model_id) - scheduler = DDPMScheduler.from_config(model_id) + scheduler = DDPMScheduler.from_pretrained(model_id) ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) ddpm.to(torch_device) diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py index 085cdb4e76..9d5c07809d 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py @@ -111,8 +111,8 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897]) + assert image.shape == (1, 16, 16, 3) + expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py index f402d2f2a7..6f1f51c7ba 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py @@ -19,9 +19,8 @@ import unittest import numpy as np import torch -import PIL from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel -from diffusers.utils import floats_tensor, load_image, slow, torch_device +from diffusers.utils import PIL_INTERPOLATION, floats_tensor, load_image, slow, torch_device from diffusers.utils.testing_utils import require_torch from ...test_pipelines_common import PipelineTesterMixin @@ -88,6 +87,27 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_inference_superresolution_fp16(self): + unet = self.dummy_uncond_unet + scheduler = DDIMScheduler() + vqvae = self.dummy_vq_model + + # put models in fp16 + unet = unet.half() + vqvae = vqvae.half() + + ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler) + ldm.to(torch_device) + ldm.set_progress_bar_config(disable=None) + + init_image = self.dummy_image.to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images + + assert image.shape == (1, 64, 64, 3) + @slow @require_torch @@ -97,7 +117,7 @@ class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/vq_diffusion/teddy_bear_pool.png" ) - init_image = init_image.resize((64, 64), resample=PIL.Image.LANCZOS) + init_image = init_image.resize((64, 64), resample=PIL_INTERPOLATION["lanczos"]) ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto") ldm.to(torch_device) diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py index 23544dfd24..3ab0efc875 100644 --- a/tests/pipelines/repaint/test_repaint.py +++ b/tests/pipelines/repaint/test_repaint.py @@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase): model_id = "google/ddpm-ema-celebahq-256" unet = UNet2DModel.from_pretrained(model_id) - scheduler = RePaintScheduler.from_config(model_id) + scheduler = RePaintScheduler.from_pretrained(model_id) repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device) diff --git a/tests/pipelines/score_sde_ve/test_score_sde_ve.py b/tests/pipelines/score_sde_ve/test_score_sde_ve.py index 55dcc1cea1..9cdf3f0191 100644 --- a/tests/pipelines/score_sde_ve/test_score_sde_ve.py +++ b/tests/pipelines/score_sde_ve/test_score_sde_ve.py @@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase): model_id = "google/ncsnpp-church-256" model = UNet2DModel.from_pretrained(model_id) - scheduler = ScoreSdeVeScheduler.from_config(model_id) + scheduler = ScoreSdeVeScheduler.from_pretrained(model_id) sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler) sde_ve.to(torch_device) diff --git a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py index de918c7e5c..7a32b74096 100644 --- a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py @@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase): init_image = init_image.resize((512, 512)) model_id = "CompVis/stable-diffusion-v1-4" - scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler") + scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = CycleDiffusionPipeline.from_pretrained( model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16" ) @@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase): init_image = init_image.resize((512, 512)) model_id = "CompVis/stable-diffusion-v1-4" - scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler") + scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None) pipe.to(torch_device) diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py index a1946e39f9..a2b48d27e6 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py @@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_inference_ddim(self): - ddim_scheduler = DDIMScheduler.from_config( + ddim_scheduler = DDIMScheduler.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" ) sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( @@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_inference_k_lms(self): - lms_scheduler = LMSDiscreteScheduler.from_config( + lms_scheduler = LMSDiscreteScheduler.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" ) sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py index 61831c64c0..91e4412425 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py @@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): "/img2img/sketch-mountains-input.jpg" ) init_image = init_image.resize((768, 512)) - lms_scheduler = LMSDiscreteScheduler.from_config( + lms_scheduler = LMSDiscreteScheduler.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" ) pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained( diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py index 4ba8e273b4..507375bddb 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py @@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" ) - lms_scheduler = LMSDiscreteScheduler.from_config( + lms_scheduler = LMSDiscreteScheduler.from_pretrained( "runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx" ) pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained( diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py new file mode 100644 index 0000000000..577023f705 --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint_legacy.py @@ -0,0 +1,95 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from diffusers import OnnxStableDiffusionInpaintPipelineLegacy +from diffusers.utils.testing_utils import ( + is_onnx_available, + load_image, + load_numpy, + require_onnxruntime, + require_torch_gpu, + slow, +) + + +if is_onnx_available(): + import onnxruntime as ort + + +@slow +@require_onnxruntime +@require_torch_gpu +class StableDiffusionOnnxInpaintLegacyPipelineIntegrationTests(unittest.TestCase): + @property + def gpu_provider(self): + return ( + "CUDAExecutionProvider", + { + "gpu_mem_limit": "15000000000", # 15GB + "arena_extend_strategy": "kSameAsRequested", + }, + ) + + @property + def gpu_options(self): + options = ort.SessionOptions() + options.enable_mem_pattern = False + return options + + def test_inference(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/red_cat_sitting_on_a_park_bench_onnx.npy" + ) + + # using the PNDM scheduler by default + pipe = OnnxStableDiffusionInpaintPipelineLegacy.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="onnx", + provider=self.gpu_provider, + sess_options=self.gpu_options, + ) + pipe.set_progress_bar_config(disable=None) + + prompt = "A red cat sitting on a park bench" + + generator = np.random.RandomState(0) + output = pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + guidance_scale=7.5, + num_inference_steps=15, + generator=generator, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 87d238c869..8dce61c3a4 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -209,8 +209,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755]) + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.5643956661224365, + 0.6017904281616211, + 0.4799129366874695, + 0.5267305374145508, + 0.5584856271743774, + 0.46413588523864746, + 0.5159522294998169, + 0.4963662028312683, + 0.47919973731040955, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -250,8 +262,8 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): [prompt], generator=generator, guidance_scale=6.0, - height=536, - width=536, + height=136, + width=136, num_inference_steps=2, output_type="np", ) @@ -259,8 +271,8 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 134, 134, 3) - expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557]) + assert image.shape == (1, 136, 136, 3) + expected_slice = np.array([0.5524, 0.5626, 0.6069, 0.4727, 0.386, 0.3995, 0.4613, 0.4328, 0.4269]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -304,8 +316,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738]) + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.5094760060310364, + 0.5674174427986145, + 0.46675148606300354, + 0.5125715136528015, + 0.5696930289268494, + 0.4674668312072754, + 0.5277683734893799, + 0.4964486062526703, + 0.494540274143219, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -370,8 +394,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.47082293033599854, + 0.5371589064598083, + 0.4562119245529175, + 0.5220914483070374, + 0.5733777284622192, + 0.4795039892196655, + 0.5465868711471558, + 0.5074326395988464, + 0.5042197108268738, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -415,8 +451,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.4707113206386566, + 0.5372191071510315, + 0.4563021957874298, + 0.5220003724098206, + 0.5734264850616455, + 0.4794946610927582, + 0.5463782548904419, + 0.5074145197868347, + 0.504422664642334, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -460,8 +508,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.47082313895225525, + 0.5371587872505188, + 0.4562119245529175, + 0.5220913887023926, + 0.5733776688575745, + 0.47950395941734314, + 0.546586811542511, + 0.5074326992034912, + 0.5042197108268738, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -497,6 +557,46 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4 + def test_stable_diffusion_vae_slicing(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + image_count = 4 + + generator = torch.Generator(device=device).manual_seed(0) + output_1 = sd_pipe( + [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np" + ) + + # make sure sliced vae decode yields the same result + sd_pipe.enable_vae_slicing() + generator = torch.Generator(device=device).manual_seed(0) + output_2 = sd_pipe( + [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np" + ) + + # there is a small discrepancy at image borders vs. full batch decode + assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3 + def test_stable_diffusion_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet @@ -533,8 +633,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image = output.images image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719]) + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [ + 0.5108221173286438, + 0.5688379406929016, + 0.4685141146183014, + 0.5098261833190918, + 0.5657756328582764, + 0.4631010890007019, + 0.5226285457611084, + 0.49129390716552734, + 0.4899061322212219, + ] + ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_num_images_per_prompt(self): @@ -563,13 +675,13 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): # test num_images_per_prompt=1 (default) images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images - assert images.shape == (1, 128, 128, 3) + assert images.shape == (1, 64, 64, 3) # test num_images_per_prompt=1 (default) for batch of prompts batch_size = 2 images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images - assert images.shape == (batch_size, 128, 128, 3) + assert images.shape == (batch_size, 64, 64, 3) # test num_images_per_prompt for single prompt num_images_per_prompt = 2 @@ -577,7 +689,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt ).images - assert images.shape == (num_images_per_prompt, 128, 128, 3) + assert images.shape == (num_images_per_prompt, 64, 64, 3) # test num_images_per_prompt for batch of prompts batch_size = 2 @@ -585,7 +697,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): [prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt ).images - assert images.shape == (batch_size * num_images_per_prompt, 128, 128, 3) + assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3) @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") def test_stable_diffusion_fp16(self): @@ -618,7 +730,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): generator = torch.Generator(device=torch_device).manual_seed(0) image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) def test_stable_diffusion_long_prompt(self): unet = self.dummy_cond_unet @@ -671,6 +783,43 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): assert cap_logger.out.count("@") == 25 assert cap_logger_3.out == "" + def test_stable_diffusion_height_width_opt(self): + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "hey" + + output = sd_pipe(prompt, num_inference_steps=1, output_type="np") + image_shape = output.images[0].shape[:2] + assert image_shape == (64, 64) + + output = sd_pipe(prompt, num_inference_steps=1, height=96, width=96, output_type="np") + image_shape = output.images[0].shape[:2] + assert image_shape == (96, 96) + + config = dict(sd_pipe.unet.config) + config["sample_size"] = 96 + sd_pipe.unet = UNet2DConditionModel.from_config(config).to(torch_device) + output = sd_pipe(prompt, num_inference_steps=1, output_type="np") + image_shape = output.images[0].shape[:2] + assert image_shape == (192, 192) + @slow @require_torch_gpu @@ -703,7 +852,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_fast_ddim(self): - scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler") + scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-1", subfolder="scheduler") sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler) sd_pipe = sd_pipe.to(torch_device) @@ -726,7 +875,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): model_id = "CompVis/stable-diffusion-v1-1" pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device) pipe.set_progress_bar_config(disable=None) - scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") + scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") pipe.scheduler = scheduler prompt = "a photograph of an astronaut riding a horse" @@ -777,6 +926,45 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): assert mem_bytes > 3.75 * 10**9 assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3 + def test_stable_diffusion_vae_slicing(self): + torch.cuda.reset_peak_memory_stats() + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "a photograph of an astronaut riding a horse" + + # enable vae slicing + pipe.enable_vae_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output_chunked = pipe( + [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + # make sure that less than 4 GB is allocated + assert mem_bytes < 4e9 + + # disable vae slicing + pipe.disable_vae_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image = output.images + + # make sure that more than 4 GB is allocated + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes > 4e9 + # There is a small discrepancy at the image borders vs. a fully batched version. + assert np.abs(image_chunked.flatten() - image.flatten()).max() < 3e-3 + def test_stable_diffusion_text2img_pipeline_fp16(self): torch.cuda.reset_peak_memory_stats() model_id = "CompVis/stable-diffusion-v1-4" @@ -819,7 +1007,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): prompt = "astronaut riding a horse" generator = torch.Generator(device=torch_device).manual_seed(0) - output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np") + output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np") image = output.images[0] assert image.shape == (512, 512, 3) @@ -839,7 +1027,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): expected_slice = np.array( [1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506] ) - assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3 + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3 elif step == 50: latents = latents.detach().cpu().numpy() assert latents.shape == (1, 4, 64, 64) @@ -871,7 +1059,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): callback_steps=1, ) assert test_callback_fn.has_been_called - assert number_of_steps == 51 + assert number_of_steps == 50 def test_stable_diffusion_low_cpu_mem_usage(self): pipeline_id = "CompVis/stable-diffusion-v1-4" diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py new file mode 100644 index 0000000000..90bfef5efe --- /dev/null +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -0,0 +1,423 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionImageVariationPipeline, + UNet2DConditionModel, +) +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_image_encoder(self): + torch.manual_seed(0) + config = CLIPVisionConfig( + hidden_size=32, + projection_dim=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + image_size=32, + patch_size=4, + ) + return CLIPVisionModelWithProjection(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_img_variation_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + image_encoder = self.dummy_image_encoder + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImageVariationPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + image_encoder=image_encoder, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + init_image, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + init_image, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5093, 0.5717, 0.4806, 0.4891, 0.5552, 0.4594, 0.5177, 0.4894, 0.4904]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_img_variation_multiple_images(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + image_encoder = self.dummy_image_encoder + + init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImageVariationPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + image_encoder=image_encoder, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + init_image, + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + ) + + image = output.images + + image_slice = image[-1, -3:, -3:, -1] + + assert image.shape == (2, 64, 64, 3) + expected_slice = np.array([0.6427, 0.5452, 0.5602, 0.5478, 0.5968, 0.6211, 0.5538, 0.5514, 0.5281]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_stable_diffusion_img_variation_num_images_per_prompt(self): + device = "cpu" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + image_encoder = self.dummy_image_encoder + + init_image = self.dummy_image.to(device) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImageVariationPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + image_encoder=image_encoder, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + # test num_images_per_prompt=1 (default) + images = sd_pipe( + init_image, + num_inference_steps=2, + output_type="np", + ).images + + assert images.shape == (1, 64, 64, 3) + + # test num_images_per_prompt=1 (default) for batch of images + batch_size = 2 + images = sd_pipe( + init_image.repeat(batch_size, 1, 1, 1), + num_inference_steps=2, + output_type="np", + ).images + + assert images.shape == (batch_size, 64, 64, 3) + + # test num_images_per_prompt for single prompt + num_images_per_prompt = 2 + images = sd_pipe( + init_image, + num_inference_steps=2, + output_type="np", + num_images_per_prompt=num_images_per_prompt, + ).images + + assert images.shape == (num_images_per_prompt, 64, 64, 3) + + # test num_images_per_prompt for batch of prompts + batch_size = 2 + images = sd_pipe( + init_image.repeat(batch_size, 1, 1, 1), + num_inference_steps=2, + output_type="np", + num_images_per_prompt=num_images_per_prompt, + ).images + + assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3) + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_img_variation_fp16(self): + """Test that stable diffusion img2img works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + image_encoder = self.dummy_image_encoder + + init_image = self.dummy_image.to(torch_device).float() + + # put models in fp16 + unet = unet.half() + vae = vae.half() + image_encoder = image_encoder.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionImageVariationPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + image_encoder=image_encoder, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + init_image, + generator=generator, + num_inference_steps=2, + output_type="np", + ).images + + assert image.shape == (1, 64, 64, 3) + + +@slow +@require_torch_gpu +class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_img_variation_pipeline_default(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/vermeer.jpg" + ) + init_image = init_image.resize((512, 512)) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/vermeer.npy" + ) + + model_id = "fusing/sd-image-variations-diffusers" + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + model_id, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + init_image, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + # img2img is flaky across GPUs even in fp32, so using MAE here + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_img_variation_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([1.83, 1.293, -0.09705, 1.256, -2.293, 1.091, -0.0809, -0.65, -2.953]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3 + elif step == 37: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([2.285, 2.703, 1.969, 0.696, -1.323, 0.9253, -0.5464, -1.521, -2.537]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2 + + test_callback_fn.has_been_called = False + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((512, 512)) + + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + "fusing/sd-image-variations-diffusers", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + init_image, + num_inference_steps=50, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 50 + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + init_image = init_image.resize((512, 512)) + + model_id = "fusing/sd-image-variations-diffusers" + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionImageVariationPipeline.from_pretrained( + model_id, scheduler=lms, safety_checker=None, torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + init_image, + guidance_scale=7.5, + generator=generator, + output_type="np", + num_inference_steps=5, + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.6 GB is allocated + assert mem_bytes < 2.6 * 10**9 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 3c0fa8aa81..0aa6e79cf8 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ) model_id = "CompVis/stable-diffusion-v1-4" - lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionImg2ImgPipeline.from_pretrained( model_id, scheduler=lms, @@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ) model_id = "CompVis/stable-diffusion-v1-4" - ddim = DDIMScheduler.from_config(model_id, subfolder="scheduler") + ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionImg2ImgPipeline.from_pretrained( model_id, scheduler=ddim, @@ -635,7 +635,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): callback_steps=1, ) assert test_callback_fn.has_been_called - assert number_of_steps == 38 + assert number_of_steps == 37 def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() @@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): init_image = init_image.resize((768, 512)) model_id = "CompVis/stable-diffusion-v1-4" - lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionImg2ImgPipeline.from_pretrained( model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16 ) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 8d269c38f9..e85fae939e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -22,12 +22,14 @@ import torch from diffusers import ( AutoencoderKL, + LMSDiscreteScheduler, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel, UNet2DModel, VQModel, ) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu from PIL import Image @@ -165,8 +167,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] - init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128)) - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionInpaintPipeline( @@ -210,8 +212,9 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - assert image.shape == (1, 128, 128, 3) - expected_slice = np.array([0.5075, 0.4485, 0.4558, 0.5369, 0.5369, 0.5236, 0.5127, 0.4983, 0.4776]) + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -224,8 +227,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] - init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128)) - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionInpaintPipeline( @@ -266,8 +269,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] - init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128)) - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) # put models in fp16 unet = unet.half() @@ -298,7 +301,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test mask_image=mask_image, ).images - assert image.shape == (1, 128, 128, 3) + assert image.shape == (1, 64, 64, 3) @slow @@ -400,7 +403,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): ) model_id = "runwayml/stable-diffusion-inpainting" - pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler") + pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -421,6 +424,45 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): assert image.shape == (512, 512, 3) assert np.abs(expected_image - image).max() < 1e-2 + def test_stable_diffusion_inpaint_pipeline_k_lms(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint" + "/yellow_cat_sitting_on_a_park_bench_k_lms.npy" + ) + + model_id = "runwayml/stable-diffusion-inpainting" + pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None) + pipe.to(torch_device) + + # switch to LMS + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 + @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU") def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() @@ -437,7 +479,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): ) model_id = "runwayml/stable-diffusion-inpainting" - pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler") + pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionInpaintPipeline.from_pretrained( model_id, safety_checker=None, @@ -466,3 +508,172 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): mem_bytes = torch.cuda.max_memory_allocated() # make sure that less than 2.2 GB is allocated assert mem_bytes < 2.2 * 10**9 + + +class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase): + def test_pil_inputs(self): + im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + im = Image.fromarray(im) + mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask = Image.fromarray((mask * 255).astype(np.uint8)) + + t_mask, t_masked = prepare_mask_and_masked_image(im, mask) + + self.assertTrue(isinstance(t_mask, torch.Tensor)) + self.assertTrue(isinstance(t_masked, torch.Tensor)) + + self.assertEqual(t_mask.ndim, 4) + self.assertEqual(t_masked.ndim, 4) + + self.assertEqual(t_mask.shape, (1, 1, 32, 32)) + self.assertEqual(t_masked.shape, (1, 3, 32, 32)) + + self.assertTrue(t_mask.dtype == torch.float32) + self.assertTrue(t_masked.dtype == torch.float32) + + self.assertTrue(t_mask.min() >= 0.0) + self.assertTrue(t_mask.max() <= 1.0) + self.assertTrue(t_masked.min() >= -1.0) + self.assertTrue(t_masked.min() <= 1.0) + + self.assertTrue(t_mask.sum() > 0.0) + + def test_np_inputs(self): + im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + im_pil = Image.fromarray(im_np) + mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) + + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil) + + self.assertTrue((t_mask_np == t_mask_pil).all()) + self.assertTrue((t_masked_np == t_masked_pil).all()) + + def test_torch_3D_2D_inputs(self): + im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy().transpose(1, 2, 0) + mask_np = mask_tensor.numpy() + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_3D_3D_inputs(self): + im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy().transpose(1, 2, 0) + mask_np = mask_tensor.numpy()[0] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_4D_2D_inputs(self): + im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy()[0].transpose(1, 2, 0) + mask_np = mask_tensor.numpy() + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_4D_3D_inputs(self): + im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy()[0].transpose(1, 2, 0) + mask_np = mask_tensor.numpy()[0] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_4D_4D_inputs(self): + im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5 + im_np = im_tensor.numpy()[0].transpose(1, 2, 0) + mask_np = mask_tensor.numpy()[0][0] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_batch_4D_3D(self): + im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5 + + im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] + mask_nps = [mask.numpy() for mask in mask_tensor] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_np = torch.cat([n[0] for n in nps]) + t_masked_np = torch.cat([n[1] for n in nps]) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_torch_batch_4D_4D(self): + im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) + mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5 + + im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] + mask_nps = [mask.numpy()[0] for mask in mask_tensor] + + t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) + nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_np = torch.cat([n[0] for n in nps]) + t_masked_np = torch.cat([n[1] for n in nps]) + + self.assertTrue((t_mask_tensor == t_mask_np).all()) + self.assertTrue((t_masked_tensor == t_masked_np).all()) + + def test_shape_mismatch(self): + # test height and width + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64)) + # test batch dim + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64)) + # test batch dim + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64)) + + def test_type_mismatch(self): + # test tensors-only + with self.assertRaises(TypeError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy()) + # test tensors-only + with self.assertRaises(TypeError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32)) + + def test_channels_first(self): + # test channels first for 3D tensors + with self.assertRaises(AssertionError): + prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32)) + + def test_tensor_range(self): + # test im <= 1 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32)) + # test im >= -1 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32)) + # test mask <= 1 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2) + # test mask >= 0 + with self.assertRaises(ValueError): + prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py index 4b535dc9df..b719566b5e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py @@ -168,7 +168,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionInpaintPipelineLegacy( @@ -227,7 +227,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionInpaintPipelineLegacy( @@ -273,7 +273,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionInpaintPipelineLegacy( @@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase): ) model_id = "CompVis/stable-diffusion-v1-4" - lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionInpaintPipeline.from_pretrained( model_id, scheduler=lms, @@ -484,4 +484,4 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase): callback_steps=1, ) assert test_callback_fn.has_been_called - assert number_of_steps == 38 + assert number_of_steps == 37 diff --git a/tests/pipelines/stable_diffusion_2/__init__.py b/tests/pipelines/stable_diffusion_2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py new file mode 100644 index 0000000000..efa4bdc6f3 --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -0,0 +1,733 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import time +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, + logging, +) +from diffusers.utils import load_numpy, slow, torch_device +from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + # SD2-specific config below + attention_head_dim=(2, 4, 8, 8), + use_linear_projection=True, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=512, + ) + return CLIPTextModel(config) + + def test_save_pretrained_from_pretrained(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + with tempfile.TemporaryDirectory() as tmpdirname: + sd_pipe.save_pretrained(tmpdirname) + sd_pipe = StableDiffusionPipeline.from_pretrained(tmpdirname) + sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + new_image = output.images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_stable_diffusion_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5649, 0.6022, 0.4804, 0.5270, 0.5585, 0.4643, 0.5159, 0.4963, 0.4793]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_pndm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5099, 0.5677, 0.4671, 0.5128, 0.5697, 0.4676, 0.5277, 0.4964, 0.4946]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_lms(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_euler_ancestral(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4715, 0.5376, 0.4569, 0.5224, 0.5734, 0.4797, 0.5465, 0.5074, 0.5046]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_attention_chunk(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + # make sure chunking the attention yields the same result + sd_pipe.enable_attention_slicing(slice_size=1) + generator = torch.Generator(device=device).manual_seed(0) + output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_fp16(self): + """Test that stable diffusion works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 64, 64, 3) + + def test_stable_diffusion_long_prompt(self): + unet = self.dummy_cond_unet + scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + do_classifier_free_guidance = True + negative_prompt = None + num_images_per_prompt = 1 + logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion") + + prompt = 25 * "@" + with CaptureLogger(logger) as cap_logger_3: + text_embeddings_3 = sd_pipe._encode_prompt( + prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + prompt = 100 * "@" + with CaptureLogger(logger) as cap_logger: + text_embeddings = sd_pipe._encode_prompt( + prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + negative_prompt = "Hello" + with CaptureLogger(logger) as cap_logger_2: + text_embeddings_2 = sd_pipe._encode_prompt( + prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape + assert text_embeddings.shape[1] == 77 + + assert cap_logger.out == cap_logger_2.out + # 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25 + assert cap_logger.out.count("@") == 25 + assert cap_logger_3.out == "" + + +@slow +@require_torch_gpu +class StableDiffusion2PipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np") + + image = output.images + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0788, 0.0823, 0.1091, 0.1165, 0.1263, 0.1459, 0.1317, 0.1507, 0.1551]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_ddim(self): + scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-base", subfolder="scheduler") + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base", scheduler=scheduler) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + + output = sd_pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy") + image = output.images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0642, 0.0382, 0.0408, 0.0395, 0.0227, 0.0942, 0.0749, 0.0669, 0.0248]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_k_lms(self): + scheduler = LMSDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2-base", subfolder="scheduler") + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base", scheduler=scheduler) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=5, output_type="numpy" + ).images + + image_slice = image[0, 253:256, 253:256, -1] + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0548, 0.0626, 0.0612, 0.0611, 0.0706, 0.0586, 0.0843, 0.0333, 0.1197]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_attention_slicing(self): + torch.cuda.reset_peak_memory_stats() + model_id = "stabilityai/stable-diffusion-2-base" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + + # make attention efficient + pipe.enable_attention_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output_chunked = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + # make sure that less than 3.75 GB is allocated + assert mem_bytes < 3.75 * 10**9 + + # disable chunking + pipe.disable_attention_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image = output.images + + # make sure that more than 3.75 GB is allocated + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes > 3.75 * 10**9 + assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3 + + def test_stable_diffusion_same_quality(self): + torch.cuda.reset_peak_memory_stats() + model_id = "stabilityai/stable-diffusion-2-base" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe = pipe.to(torch_device) + pipe.enable_attention_slicing() + pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output_chunked = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + pipe = StableDiffusionPipeline.from_pretrained(model_id) + pipe = pipe.to(torch_device) + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy") + image = output.images + + # Make sure results are close enough + diff = np.abs(image_chunked.flatten() - image.flatten()) + # They ARE different since ops are not run always at the same precision + # however, they should be extremely close. + assert diff.mean() < 5e-2 + + def test_stable_diffusion_text2img_pipeline_default(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-text2img/astronaut_riding_a_horse.npy" + ) + + model_id = "stabilityai/stable-diffusion-2-base" + pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 5e-3 + + def test_stable_diffusion_text2img_intermediate_state(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([1.8606, 1.3169, -0.0691, 1.2374, -2.309, 1.077, -0.1084, -0.6774, -2.9594]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3 + elif step == 20: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 64, 64) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array([1.0757, 1.1860, 1.1410, 0.4645, -0.2476, 0.6100, -0.7755, -0.8841, -0.9497]) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2 + + test_callback_fn.has_been_called = False + + pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-base", revision="fp16", torch_dtype=torch.float16 + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Andromeda galaxy in a bottle" + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + num_inference_steps=20, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 20 + + def test_stable_diffusion_low_cpu_mem_usage(self): + pipeline_id = "stabilityai/stable-diffusion-2-base" + + start_time = time.time() + pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16 + ) + pipeline_low_cpu_mem_usage.to(torch_device) + low_cpu_mem_usage_time = time.time() - start_time + + start_time = time.time() + _ = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False + ) + normal_load_time = time.time() - start_time + + assert 2 * low_cpu_mem_usage_time < normal_load_time + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipeline_id = "stabilityai/stable-diffusion-2-base" + prompt = "Andromeda galaxy in a bottle" + + pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16) + pipeline = pipeline.to(torch_device) + pipeline.enable_attention_slicing(1) + pipeline.enable_sequential_cpu_offload() + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipeline(prompt, generator=generator, num_inference_steps=5) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.8 GB is allocated + assert mem_bytes < 2.8 * 10**9 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py new file mode 100644 index 0000000000..f10f0e1798 --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py @@ -0,0 +1,99 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline +from diffusers.utils import is_flax_available, slow +from diffusers.utils.testing_utils import require_flax + + +if is_flax_available(): + import jax + import jax.numpy as jnp + from flax.jax_utils import replicate + from flax.training.common_utils import shard + + +@slow +@require_flax +class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + + def test_stable_diffusion_flax(self): + sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2", + revision="bf16", + dtype=jnp.bfloat16, + ) + + prompt = "A painting of a squirrel eating a burger" + num_samples = jax.device_count() + prompt = num_samples * [prompt] + prompt_ids = sd_pipe.prepare_inputs(prompt) + + params = replicate(params) + prompt_ids = shard(prompt_ids) + + prng_seed = jax.random.PRNGKey(0) + prng_seed = jax.random.split(prng_seed, jax.device_count()) + + images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0] + assert images.shape == (jax.device_count(), 1, 768, 768, 3) + + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + image_slice = images[0, 253:256, 253:256, -1] + + output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) + expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512]) + print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 + + def test_stable_diffusion_dpm_flax(self): + model_id = "stabilityai/stable-diffusion-2" + scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler") + sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained( + model_id, + scheduler=scheduler, + revision="bf16", + dtype=jnp.bfloat16, + ) + params["scheduler"] = scheduler_params + + prompt = "A painting of a squirrel eating a burger" + num_samples = jax.device_count() + prompt = num_samples * [prompt] + prompt_ids = sd_pipe.prepare_inputs(prompt) + + params = replicate(params) + prompt_ids = shard(prompt_ids) + + prng_seed = jax.random.PRNGKey(0) + prng_seed = jax.random.split(prng_seed, jax.device_count()) + + images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0] + assert images.shape == (jax.device_count(), 1, 768, 768, 3) + + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + image_slice = images[0, 253:256, 253:256, -1] + + output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) + expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297]) + print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py new file mode 100644 index 0000000000..b420570f07 --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -0,0 +1,345 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel +from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet_inpaint(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + # SD2-specific config below + attention_head_dim=(2, 4, 8, 8), + use_linear_projection=True, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=512, + ) + return CLIPTextModel(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_stable_diffusion_inpaint(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet_inpaint + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + text_encoder = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + image=init_image, + mask_image=mask_image, + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + image=init_image, + mask_image=mask_image, + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4727, 0.5735, 0.3941, 0.5446, 0.5926, 0.4394, 0.5062, 0.4654, 0.4476]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_inpaint_fp16(self): + """Test that stable diffusion inpaint works with fp16""" + unet = self.dummy_cond_unet_inpaint + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + text_encoder = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64)) + + # put models in fp16 + unet = unet.half() + vae = vae.half() + text_encoder = text_encoder.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionInpaintPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], + generator=generator, + num_inference_steps=2, + output_type="np", + image=init_image, + mask_image=mask_image, + ).images + + assert image.shape == (1, 64, 64, 3) + + +# @slow +@require_torch_gpu +class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_inpaint_pipeline(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-inpaint/init_image.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint" + "/yellow_cat_sitting_on_a_park_bench.npy" + ) + + model_id = "stabilityai/stable-diffusion-2-inpainting" + pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_inpaint_pipeline_fp16(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-inpaint/init_image.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint" + "/yellow_cat_sitting_on_a_park_bench_fp16.npy" + ) + + model_id = "stabilityai/stable-diffusion-2-inpainting" + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + revision="fp16", + torch_dtype=torch.float16, + safety_checker=None, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 5e-1 + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-inpaint/init_image.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png" + ) + + model_id = "stabilityai/stable-diffusion-2-inpainting" + pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler") + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + safety_checker=None, + scheduler=pndm, + device_map="auto", + revision="fp16", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + prompt=prompt, + image=init_image, + mask_image=mask_image, + generator=generator, + num_inference_steps=5, + output_type="np", + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.65 GB is allocated + assert mem_bytes < 2.65 * 10**9 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py new file mode 100644 index 0000000000..2092e153ee --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py @@ -0,0 +1,315 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionUpscalePipeline, UNet2DConditionModel +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusionUpscalePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet_upscale(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 32, 64), + layers_per_block=2, + sample_size=32, + in_channels=7, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + # SD2-specific config below + attention_head_dim=8, + use_linear_projection=True, + only_cross_attention=(True, True, False), + num_class_embeds=100, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=512, + ) + return CLIPTextModel(config) + + def test_stable_diffusion_upscale(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet_upscale + low_res_scheduler = DDPMScheduler() + scheduler = DDIMScheduler(prediction_type="v_prediction") + vae = self.dummy_vae + text_encoder = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionUpscalePipeline( + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + max_noise_level=350, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + image=low_res_image, + generator=generator, + guidance_scale=6.0, + noise_level=20, + num_inference_steps=2, + output_type="np", + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + image=low_res_image, + generator=generator, + guidance_scale=6.0, + noise_level=20, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + expected_height_width = low_res_image.size[0] * 4 + assert image.shape == (1, expected_height_width, expected_height_width, 3) + expected_slice = np.array([0.2562, 0.3606, 0.4204, 0.4469, 0.4822, 0.4647, 0.5315, 0.5748, 0.5606]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_upscale_fp16(self): + """Test that stable diffusion upscale works with fp16""" + unet = self.dummy_cond_unet_upscale + low_res_scheduler = DDPMScheduler() + scheduler = DDIMScheduler(prediction_type="v_prediction") + vae = self.dummy_vae + text_encoder = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + + # put models in fp16, except vae as it overflows in fp16 + unet = unet.half() + text_encoder = text_encoder.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionUpscalePipeline( + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + max_noise_level=350, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], + image=low_res_image, + generator=generator, + num_inference_steps=2, + output_type="np", + ).images + + expected_height_width = low_res_image.size[0] * 4 + assert image.shape == (1, expected_height_width, expected_height_width, 3) + + +@slow +@require_torch_gpu +class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_upscale_pipeline(self): + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-upscale/low_res_cat.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale" + "/upsampled_cat.npy" + ) + + model_id = "stabilityai/stable-diffusion-x4-upscaler" + pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "a cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-3 + + def test_stable_diffusion_upscale_pipeline_fp16(self): + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-upscale/low_res_cat.png" + ) + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale" + "/upsampled_cat_fp16.npy" + ) + + model_id = "stabilityai/stable-diffusion-x4-upscaler" + pipe = StableDiffusionUpscalePipeline.from_pretrained( + model_id, + revision="fp16", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "a cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + image=image, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 5e-1 + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/sd2-upscale/low_res_cat.png" + ) + + model_id = "stabilityai/stable-diffusion-x4-upscaler" + pipe = StableDiffusionUpscalePipeline.from_pretrained( + model_id, + revision="fp16", + torch_dtype=torch.float16, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing(1) + pipe.enable_sequential_cpu_offload() + + prompt = "a cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipe( + prompt=prompt, + image=image, + generator=generator, + num_inference_steps=5, + output_type="np", + ) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.65 GB is allocated + assert mem_bytes < 2.65 * 10**9 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py new file mode 100644 index 0000000000..bbe4f49436 --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -0,0 +1,474 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import time +import unittest + +import numpy as np +import torch + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerDiscreteScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.utils import load_numpy, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class StableDiffusion2VPredictionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + # SD2-specific config below + attention_head_dim=(2, 4, 8, 8), + use_linear_projection=True, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=64, + ) + return CLIPTextModel(config) + + def test_stable_diffusion_v_pred_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + prediction_type="v_prediction", + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.6424, 0.6109, 0.494, 0.5088, 0.4984, 0.4525, 0.5059, 0.5068, 0.4474]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_v_pred_k_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="v_prediction" + ) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4616, 0.5184, 0.4887, 0.5111, 0.4839, 0.48, 0.5119, 0.5263, 0.4776]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_v_pred_fp16(self): + """Test that stable diffusion v-prediction works with fp16""" + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + prediction_type="v_prediction", + ) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 64, 64, 3) + + +@slow +@require_torch_gpu +class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_v_pred_default(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.enable_attention_slicing() + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np") + + image = output.images + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 768, 768, 3) + expected_slice = np.array([0.0567, 0.057, 0.0416, 0.0463, 0.0433, 0.06, 0.0517, 0.0526, 0.0866]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_v_pred_euler(self): + scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler") + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.enable_attention_slicing() + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + + output = sd_pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy") + image = output.images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 768, 768, 3) + expected_slice = np.array([0.0351, 0.0376, 0.0505, 0.0424, 0.0551, 0.0656, 0.0471, 0.0276, 0.0596]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_v_pred_dpm(self): + """ + TODO: update this test after making DPM compatible with V-prediction! + """ + scheduler = DPMSolverMultistepScheduler.from_pretrained( + "stabilityai/stable-diffusion-2", subfolder="scheduler" + ) + sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.enable_attention_slicing() + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=5, output_type="numpy" + ).images + + image_slice = image[0, 253:256, 253:256, -1] + assert image.shape == (1, 768, 768, 3) + expected_slice = np.array([0.2049, 0.2115, 0.2323, 0.2416, 0.256, 0.2484, 0.2517, 0.2358, 0.236]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_attention_slicing_v_pred(self): + torch.cuda.reset_peak_memory_stats() + model_id = "stabilityai/stable-diffusion-2" + pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "a photograph of an astronaut riding a horse" + + # make attention efficient + pipe.enable_attention_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output_chunked = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image_chunked = output_chunked.images + + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + # make sure that less than 5.5 GB is allocated + assert mem_bytes < 5.5 * 10**9 + + # disable slicing + pipe.disable_attention_slicing() + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + output = pipe( + [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy" + ) + image = output.images + + # make sure that more than 5.5 GB is allocated + mem_bytes = torch.cuda.max_memory_allocated() + assert mem_bytes > 5.5 * 10**9 + assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3 + + def test_stable_diffusion_text2img_pipeline_v_pred_default(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" + "sd2-text2img/astronaut_riding_a_horse_v_pred.npy" + ) + + pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2") + pipe.to(torch_device) + pipe.enable_attention_slicing() + pipe.set_progress_bar_config(disable=None) + + prompt = "astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (768, 768, 3) + assert np.abs(expected_image - image).max() < 5e-3 + + def test_stable_diffusion_text2img_pipeline_v_pred_fp16(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" + "sd2-text2img/astronaut_riding_a_horse_v_pred_fp16.npy" + ) + + pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2", revision="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (768, 768, 3) + assert np.abs(expected_image - image).max() < 5e-1 + + def test_stable_diffusion_text2img_intermediate_state_v_pred(self): + number_of_steps = 0 + + def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None: + test_callback_fn.has_been_called = True + nonlocal number_of_steps + number_of_steps += 1 + if step == 0: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 96, 96) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.2543, -1.2755, 0.4261, -0.9555, -1.173, -0.5892, 2.4159, 0.1554, -1.2098] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3 + elif step == 19: + latents = latents.detach().cpu().numpy() + assert latents.shape == (1, 4, 96, 96) + latents_slice = latents[0, -3:, -3:, -1] + expected_slice = np.array( + [-0.9572, -0.967, -0.6152, 0.0894, -0.699, -0.2344, 1.5465, -0.0357, -0.1141] + ) + assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2 + + test_callback_fn.has_been_called = False + + pipe = StableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2", revision="fp16", torch_dtype=torch.float16 + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "Andromeda galaxy in a bottle" + + generator = torch.Generator(device=torch_device).manual_seed(0) + with torch.autocast(torch_device): + pipe( + prompt=prompt, + num_inference_steps=20, + guidance_scale=7.5, + generator=generator, + callback=test_callback_fn, + callback_steps=1, + ) + assert test_callback_fn.has_been_called + assert number_of_steps == 20 + + def test_stable_diffusion_low_cpu_mem_usage_v_pred(self): + pipeline_id = "stabilityai/stable-diffusion-2" + + start_time = time.time() + pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16 + ) + pipeline_low_cpu_mem_usage.to(torch_device) + low_cpu_mem_usage_time = time.time() - start_time + + start_time = time.time() + _ = StableDiffusionPipeline.from_pretrained( + pipeline_id, revision="fp16", torch_dtype=torch.float16, low_cpu_mem_usage=False + ) + normal_load_time = time.time() - start_time + + assert 2 * low_cpu_mem_usage_time < normal_load_time + + def test_stable_diffusion_pipeline_with_sequential_cpu_offloading_v_pred(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + pipeline_id = "stabilityai/stable-diffusion-2" + prompt = "Andromeda galaxy in a bottle" + + pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16) + pipeline = pipeline.to(torch_device) + pipeline.enable_attention_slicing(1) + pipeline.enable_sequential_cpu_offload() + + generator = torch.Generator(device=torch_device).manual_seed(0) + _ = pipeline(prompt, generator=generator, num_inference_steps=5) + + mem_bytes = torch.cuda.max_memory_allocated() + # make sure that less than 2.8 GB is allocated + assert mem_bytes < 2.8 * 10**9 diff --git a/tests/pipelines/stable_diffusion_safe/__init__.py b/tests/pipelines/stable_diffusion_safe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py new file mode 100644 index 0000000000..dbb9914793 --- /dev/null +++ b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py @@ -0,0 +1,435 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline +from diffusers.utils import floats_tensor, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @property + def dummy_image(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) + return image + + @property + def dummy_cond_unet(self): + torch.manual_seed(0) + model = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + return model + + @property + def dummy_vae(self): + torch.manual_seed(0) + model = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + return model + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModel(config) + + @property + def dummy_extractor(self): + def extract(*args, **kwargs): + class Out: + def __init__(self): + self.pixel_values = torch.ones([0]) + + def to(self, device): + self.pixel_values.to(device) + return self + + return Out() + + return extract + + def test_safe_diffusion_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_pndm(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_no_safety_checker(self): + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None + ) + assert isinstance(pipe, StableDiffusionPipeline) + assert isinstance(pipe.scheduler, LMSDiscreteScheduler) + assert pipe.safety_checker is None + + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + + # check that there's no error when saving a pipeline with one of the models being None + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = StableDiffusionPipeline.from_pretrained(tmpdirname) + + # sanity check that the pipeline still works + assert pipe.safety_checker is None + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") + def test_stable_diffusion_fp16(self): + """Test that stable diffusion works with fp16""" + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # put models in fp16 + unet = unet.half() + vae = vae.half() + bert = bert.half() + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=torch_device).manual_seed(0) + image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images + + assert image.shape == (1, 64, 64, 3) + + +@slow +@require_torch_gpu +class SafeDiffusionPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_harm_safe_stable_diffusion(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None) + sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = ( + "portrait of girl with smokey eyes makeup in abandoned hotel, grange clothes, redshift, wide high angle" + " coloured polaroid photograph with flash, kodak film, hyper real, stunning moody cinematography, with" + " anamorphic lenses, by maripol, fallen angels by wong kar - wai, style of suspiria and neon demon and" + " children from bahnhof zoo, detailed " + ) + seed = 4003660346 + guidance_scale = 7 + + # without safety guidance (sld_guidance_scale = 0) + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=0, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176] + + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + # without safety guidance (strong configuration) + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=2000, + sld_warmup_steps=7, + sld_threshold=0.025, + sld_momentum_scale=0.5, + sld_mom_beta=0.7, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719] + + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_nudity_safe_stable_diffusion(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None) + sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "padme amidala taking a bath artwork, safe for work, no nudity" + seed = 2734971755 + guidance_scale = 7 + + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=0, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297] + + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=2000, + sld_warmup_steps=7, + sld_threshold=0.025, + sld_momentum_scale=0.5, + sld_mom_beta=0.7, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443] + + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_nudity_safetychecker_safe_stable_diffusion(self): + sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = ( + "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c." + " leyendecker" + ) + seed = 1044355234 + guidance_scale = 12 + + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=0, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7 + + generator = torch.Generator(device=torch_device).manual_seed(seed) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=guidance_scale, + num_inference_steps=50, + output_type="np", + width=512, + height=512, + sld_guidance_scale=2000, + sld_warmup_steps=7, + sld_threshold=0.025, + sld_momentum_scale=0.5, + sld_mom_beta=0.7, + ) + + image = output.images + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561]) + assert image.shape == (1, 512, 512, 3) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/versatile_diffusion/__init__.py b/tests/pipelines/versatile_diffusion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py new file mode 100644 index 0000000000..9fb6ca522f --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_dual_guided.py @@ -0,0 +1,112 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionDualGuidedPipeline +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionDualGuidedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_remove_unused_weights_save_load(self): + pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion") + # remove text_unet + pipe.remove_unused_weights() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + second_prompt = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt="first prompt", + image=second_prompt, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(tmpdirname) + + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + new_image = pipe( + prompt="first prompt", + image=second_prompt, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_inference_dual_guided(self): + pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion") + pipe.remove_unused_weights() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + first_prompt = "cyberpunk 2077" + second_prompt = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt=first_prompt, + image=second_prompt, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=50, + output_type="numpy", + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.014, 0.0112, 0.0136, 0.0145, 0.0107, 0.0113, 0.0272, 0.0215, 0.0216]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_variation.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_variation.py new file mode 100644 index 0000000000..1711b75299 --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_image_variation.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionImageVariationPipeline +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase): + def test_inference_image_variations(self): + pipe = VersatileDiffusionImageVariationPipeline.from_pretrained("shi-labs/versatile-diffusion") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + image_prompt = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + image=image_prompt, + generator=generator, + guidance_scale=7.5, + num_inference_steps=50, + output_type="numpy", + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.1205, 0.1914, 0.2289, 0.0883, 0.1595, 0.1683, 0.0703, 0.1493, 0.1298]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py new file mode 100644 index 0000000000..9387d141d1 --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py @@ -0,0 +1,128 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionPipeline +from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionMegaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_from_pretrained_save_pretrained(self): + pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt_image = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe.dual_guided( + prompt="first prompt", + image=prompt_image, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = VersatileDiffusionPipeline.from_pretrained(tmpdirname, torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + new_image = pipe.dual_guided( + prompt="first prompt", + image=prompt_image, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=2, + output_type="numpy", + ).images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_inference_dual_guided_then_text_to_image(self): + pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "cyberpunk 2077" + init_image = load_image( + "https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg" + ) + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe.dual_guided( + prompt=prompt, + image=init_image, + text_to_image_strength=0.75, + generator=generator, + guidance_scale=7.5, + num_inference_steps=50, + output_type="numpy", + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0081, 0.0032, 0.0002, 0.0056, 0.0027, 0.0000, 0.0051, 0.0020, 0.0007]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2 + + prompt = "A painting of a squirrel eating a burger " + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe.text_to_image( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy" + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2 + + image = pipe.image_variation(init_image, generator=generator, output_type="numpy").images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.3403, 0.1809, 0.0938, 0.3855, 0.2393, 0.1243, 0.4028, 0.3110, 0.1799]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2 diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_text_to_image.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_text_to_image.py new file mode 100644 index 0000000000..027819efee --- /dev/null +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_text_to_image.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers import VersatileDiffusionTextToImagePipeline +from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class VersatileDiffusionTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pass + + +@slow +@require_torch_gpu +class VersatileDiffusionTextToImagePipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_remove_unused_weights_save_load(self): + pipe = VersatileDiffusionTextToImagePipeline.from_pretrained("shi-labs/versatile-diffusion") + # remove text_unet + pipe.remove_unused_weights() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger " + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy" + ).images + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(tmpdirname) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator = generator.manual_seed(0) + new_image = pipe( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=2, output_type="numpy" + ).images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass" + + def test_inference_text2img(self): + pipe = VersatileDiffusionTextToImagePipeline.from_pretrained("shi-labs/versatile-diffusion") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger " + generator = torch.Generator(device=torch_device).manual_seed(0) + image = pipe( + prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy" + ).images + + image_slice = image[0, 253:256, 253:256, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py index 5eb32d40d4..87e29cbc97 100644 --- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -20,7 +20,8 @@ import numpy as np import torch from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel -from diffusers.utils import load_image, slow, torch_device +from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings +from diffusers.utils import load_numpy, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -45,6 +46,10 @@ class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): def num_embeds_ada_norm(self): return 12 + @property + def text_embedder_hidden_size(self): + return 32 + @property def dummy_vqvae(self): torch.manual_seed(0) @@ -71,7 +76,7 @@ class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): config = CLIPTextConfig( bos_token_id=0, eos_token_id=2, - hidden_size=32, + hidden_size=self.text_embedder_hidden_size, intermediate_size=37, layer_norm_eps=1e-05, num_attention_heads=4, @@ -111,9 +116,15 @@ class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): tokenizer = self.dummy_tokenizer transformer = self.dummy_transformer scheduler = VQDiffusionScheduler(self.num_embed) + learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings(learnable=False) pipe = VQDiffusionPipeline( - vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler + vqvae=vqvae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, ) pipe = pipe.to(device) pipe.set_progress_bar_config(disable=None) @@ -139,6 +150,50 @@ class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_vq_diffusion_classifier_free_sampling(self): + device = "cpu" + + vqvae = self.dummy_vqvae + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + transformer = self.dummy_transformer + scheduler = VQDiffusionScheduler(self.num_embed) + learned_classifier_free_sampling_embeddings = LearnedClassifierFreeSamplingEmbeddings( + learnable=True, hidden_size=self.text_embedder_hidden_size, length=tokenizer.model_max_length + ) + + pipe = VQDiffusionPipeline( + vqvae=vqvae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings, + ) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + prompt = "teddy bear playing in the pool" + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np") + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = pipe( + [prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2 + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 24, 24, 3) + + expected_slice = np.array([0.6647, 0.6531, 0.5303, 0.5891, 0.5726, 0.4439, 0.6304, 0.5564, 0.4912]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + @slow @require_torch_gpu @@ -149,12 +204,11 @@ class VQDiffusionPipelineIntegrationTests(unittest.TestCase): gc.collect() torch.cuda.empty_cache() - def test_vq_diffusion(self): - expected_image = load_image( + def test_vq_diffusion_classifier_free_sampling(self): + expected_image = load_numpy( "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/vq_diffusion/teddy_bear_pool.png" + "/vq_diffusion/teddy_bear_pool_classifier_free_sampling.npy" ) - expected_image = np.array(expected_image, dtype=np.float32) / 255.0 pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq") pipeline = pipeline.to(torch_device) @@ -163,7 +217,6 @@ class VQDiffusionPipelineIntegrationTests(unittest.TestCase): generator = torch.Generator(device=torch_device).manual_seed(0) output = pipeline( "teddy bear playing in the pool", - truncation_rate=0.86, num_images_per_prompt=1, generator=generator, output_type="np", diff --git a/tests/test_config.py b/tests/test_config.py index 8ae8e1d9e1..2a021c4ced 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,12 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -import os import tempfile import unittest -import diffusers from diffusers import ( DDIMScheduler, DDPMScheduler, @@ -29,6 +26,7 @@ from diffusers import ( logging, ) from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import deprecate from diffusers.utils.testing_utils import CaptureLogger @@ -81,7 +79,7 @@ class SampleObject3(ConfigMixin): class ConfigTester(unittest.TestCase): def test_load_not_from_mixin(self): with self.assertRaises(ValueError): - ConfigMixin.from_config("dummy_path") + ConfigMixin.load_config("dummy_path") def test_register_to_config(self): obj = SampleObject() @@ -131,7 +129,7 @@ class ConfigTester(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: obj.save_config(tmpdirname) - new_obj = SampleObject.from_config(tmpdirname) + new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname)) new_config = new_obj.config # unfreeze configs @@ -142,117 +140,13 @@ class ConfigTester(unittest.TestCase): assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json assert config == new_config - def test_save_load_from_different_config(self): - obj = SampleObject() - - # mock add obj class to `diffusers` - setattr(diffusers, "SampleObject", SampleObject) - logger = logging.get_logger("diffusers.configuration_utils") - - with tempfile.TemporaryDirectory() as tmpdirname: - obj.save_config(tmpdirname) - with CaptureLogger(logger) as cap_logger_1: - new_obj_1 = SampleObject2.from_config(tmpdirname) - - # now save a config parameter that is not expected - with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f: - data = json.load(f) - data["unexpected"] = True - - with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f: - json.dump(data, f) - - with CaptureLogger(logger) as cap_logger_2: - new_obj_2 = SampleObject.from_config(tmpdirname) - - with CaptureLogger(logger) as cap_logger_3: - new_obj_3 = SampleObject2.from_config(tmpdirname) - - assert new_obj_1.__class__ == SampleObject2 - assert new_obj_2.__class__ == SampleObject - assert new_obj_3.__class__ == SampleObject2 - - assert cap_logger_1.out == "" - assert ( - cap_logger_2.out - == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will" - " be ignored. Please verify your config.json configuration file.\n" - ) - assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out - - def test_save_load_compatible_schedulers(self): - SampleObject2._compatible_classes = ["SampleObject"] - SampleObject._compatible_classes = ["SampleObject2"] - - obj = SampleObject() - - # mock add obj class to `diffusers` - setattr(diffusers, "SampleObject", SampleObject) - setattr(diffusers, "SampleObject2", SampleObject2) - logger = logging.get_logger("diffusers.configuration_utils") - - with tempfile.TemporaryDirectory() as tmpdirname: - obj.save_config(tmpdirname) - - # now save a config parameter that is expected by another class, but not origin class - with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f: - data = json.load(f) - data["f"] = [0, 0] - data["unexpected"] = True - - with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f: - json.dump(data, f) - - with CaptureLogger(logger) as cap_logger: - new_obj = SampleObject.from_config(tmpdirname) - - assert new_obj.__class__ == SampleObject - - assert ( - cap_logger.out - == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will" - " be ignored. Please verify your config.json configuration file.\n" - ) - - def test_save_load_from_different_config_comp_schedulers(self): - SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"] - SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"] - SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"] - - obj = SampleObject() - - # mock add obj class to `diffusers` - setattr(diffusers, "SampleObject", SampleObject) - setattr(diffusers, "SampleObject2", SampleObject2) - setattr(diffusers, "SampleObject3", SampleObject3) - logger = logging.get_logger("diffusers.configuration_utils") - logger.setLevel(diffusers.logging.INFO) - - with tempfile.TemporaryDirectory() as tmpdirname: - obj.save_config(tmpdirname) - - with CaptureLogger(logger) as cap_logger_1: - new_obj_1 = SampleObject.from_config(tmpdirname) - - with CaptureLogger(logger) as cap_logger_2: - new_obj_2 = SampleObject2.from_config(tmpdirname) - - with CaptureLogger(logger) as cap_logger_3: - new_obj_3 = SampleObject3.from_config(tmpdirname) - - assert new_obj_1.__class__ == SampleObject - assert new_obj_2.__class__ == SampleObject2 - assert new_obj_3.__class__ == SampleObject3 - - assert cap_logger_1.out == "" - assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n" - assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n" - def test_load_ddim_from_pndm(self): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") + ddim = DDIMScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" + ) assert ddim.__class__ == DDIMScheduler # no warning should be thrown @@ -262,7 +156,7 @@ class ConfigTester(unittest.TestCase): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - euler = EulerDiscreteScheduler.from_config( + euler = EulerDiscreteScheduler.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" ) @@ -274,7 +168,7 @@ class ConfigTester(unittest.TestCase): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - euler = EulerAncestralDiscreteScheduler.from_config( + euler = EulerAncestralDiscreteScheduler.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" ) @@ -286,7 +180,9 @@ class ConfigTester(unittest.TestCase): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") + pndm = PNDMScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" + ) assert pndm.__class__ == PNDMScheduler # no warning should be thrown @@ -296,20 +192,30 @@ class ConfigTester(unittest.TestCase): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - ddpm = DDPMScheduler.from_config( + ddpm = DDPMScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", + subfolder="scheduler", + prediction_type="sample", + beta_end=8, + ) + + with CaptureLogger(logger) as cap_logger_2: + ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88) + + with CaptureLogger(logger) as cap_logger: + deprecate("remove this case", "0.10.0", "remove") + ddpm_3 = DDPMScheduler.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler", predict_epsilon=False, beta_end=8, ) - with CaptureLogger(logger) as cap_logger_2: - ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88) - assert ddpm.__class__ == DDPMScheduler - assert ddpm.config.predict_epsilon is False + assert ddpm.config.prediction_type == "sample" assert ddpm.config.beta_end == 8 assert ddpm_2.config.beta_start == 88 + assert ddpm_3.config.prediction_type == "sample" # no warning should be thrown assert cap_logger.out == "" @@ -319,7 +225,7 @@ class ConfigTester(unittest.TestCase): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - dpm = DPMSolverMultistepScheduler.from_config( + dpm = DPMSolverMultistepScheduler.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" ) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index eabe6ada9f..cad1887f4d 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -130,7 +130,7 @@ class ModelTesterMixin: expected_arg_names = ["sample", "timestep"] self.assertListEqual(arg_names[:2], expected_arg_names) - def test_model_from_config(self): + def test_model_from_pretrained(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) @@ -140,8 +140,8 @@ class ModelTesterMixin: # test if the model can be loaded from the config # and has all the expected shape with tempfile.TemporaryDirectory() as tmpdirname: - model.save_config(tmpdirname) - new_model = self.model_class.from_config(tmpdirname) + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname) new_model.to(torch_device) new_model.eval() @@ -265,3 +265,23 @@ class ModelTesterMixin: # check disable works model.disable_gradient_checkpointing() self.assertFalse(model.is_gradient_checkpointing) + + def test_deprecated_kwargs(self): + has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters + has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" + " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" + f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" + " from `_deprecated_kwargs = []`" + ) diff --git a/tests/test_modeling_common_flax.py b/tests/test_modeling_common_flax.py index 61849b2231..8945aed7c9 100644 --- a/tests/test_modeling_common_flax.py +++ b/tests/test_modeling_common_flax.py @@ -1,3 +1,5 @@ +import inspect + from diffusers.utils import is_flax_available from diffusers.utils.testing_utils import require_flax @@ -42,3 +44,23 @@ class FlaxModelTesterMixin: self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_deprecated_kwargs(self): + has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters + has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" + " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" + f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" + " from `_deprecated_kwargs = []`" + ) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 4559d713ed..6ae11e122d 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -14,8 +14,10 @@ # limitations under the License. import gc +import json import os import random +import shutil import tempfile import unittest @@ -29,19 +31,23 @@ from diffusers import ( DDIMScheduler, DDPMPipeline, DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, PNDMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, StableDiffusionPipeline, UNet2DConditionModel, UNet2DModel, - VQModel, logging, ) from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu +from parameterized import parameterized from PIL import Image from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -86,6 +92,24 @@ class DownloadTests(unittest.TestCase): # None of the downloaded files should be a flax file even if we have some here: # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack assert not any(f.endswith(".msgpack") for f in files) + # We need to never convert this tiny model to safetensors for this test to pass + assert not any(f.endswith(".safetensors") for f in files) + + def test_download_safetensors(self): + with tempfile.TemporaryDirectory() as tmpdirname: + # pipeline has Flax weights + _ = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe-safetensors", + safety_checker=None, + cache_dir=tmpdirname, + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] + files = [item for sublist in all_root_files for item in sublist] + + # None of the downloaded files should be a pytorch file even if we have some here: + # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack + assert not any(f.endswith(".bin") for f in files) def test_download_no_safety_checker(self): prompt = "hello" @@ -188,7 +212,7 @@ class CustomPipelineTests(unittest.TestCase): # compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102 assert output_str == "This is a test" - def test_local_custom_pipeline(self): + def test_local_custom_pipeline_repo(self): local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline") pipeline = DiffusionPipeline.from_pretrained( "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path @@ -201,6 +225,20 @@ class CustomPipelineTests(unittest.TestCase): # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102 assert output_str == "This is a local test" + def test_local_custom_pipeline_file(self): + local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline") + local_custom_pipeline_path = os.path.join(local_custom_pipeline_path, "what_ever.py") + pipeline = DiffusionPipeline.from_pretrained( + "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path + ) + pipeline = pipeline.to(torch_device) + images, output_str = pipeline(num_inference_steps=2, output_type="np") + + assert pipeline.__class__.__name__ == "CustomLocalPipeline" + assert images[0].shape == (1, 32, 32, 3) + # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102 + assert output_str == "This is a local test" + @slow @require_torch_gpu def test_load_pipeline_from_git(self): @@ -229,7 +267,6 @@ class CustomPipelineTests(unittest.TestCase): class PipelineFastTests(unittest.TestCase): - @property def dummy_image(self): batch_size = 1 num_channels = 3 @@ -238,13 +275,12 @@ class PipelineFastTests(unittest.TestCase): image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) return image - @property - def dummy_uncond_unet(self): + def dummy_uncond_unet(self, sample_size=32): torch.manual_seed(0) model = UNet2DModel( block_out_channels=(32, 64), layers_per_block=2, - sample_size=32, + sample_size=sample_size, in_channels=3, out_channels=3, down_block_types=("DownBlock2D", "AttnDownBlock2D"), @@ -252,13 +288,12 @@ class PipelineFastTests(unittest.TestCase): ) return model - @property - def dummy_cond_unet(self): + def dummy_cond_unet(self, sample_size=32): torch.manual_seed(0) model = UNet2DConditionModel( block_out_channels=(32, 64), layers_per_block=2, - sample_size=32, + sample_size=sample_size, in_channels=4, out_channels=4, down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), @@ -267,34 +302,6 @@ class PipelineFastTests(unittest.TestCase): ) return model - @property - def dummy_cond_unet_inpaint(self): - torch.manual_seed(0) - model = UNet2DConditionModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, - in_channels=9, - out_channels=4, - down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), - up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), - cross_attention_dim=32, - ) - return model - - @property - def dummy_vq_model(self): - torch.manual_seed(0) - model = VQModel( - block_out_channels=[32, 64], - in_channels=3, - out_channels=3, - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], - latent_channels=3, - ) - return model - @property def dummy_vae(self): torch.manual_seed(0) @@ -339,17 +346,44 @@ class PipelineFastTests(unittest.TestCase): return extract - def test_components(self): + @parameterized.expand( + [ + [DDIMScheduler, DDIMPipeline, 32], + [DDPMScheduler, DDPMPipeline, 32], + [DDIMScheduler, DDIMPipeline, (32, 64)], + [DDPMScheduler, DDPMPipeline, (64, 32)], + ] + ) + def test_uncond_unet_components(self, scheduler_fn=DDPMScheduler, pipeline_fn=DDPMPipeline, sample_size=32): + unet = self.dummy_uncond_unet(sample_size) + scheduler = scheduler_fn() + pipeline = pipeline_fn(unet, scheduler).to(torch_device) + + # Device type MPS is not supported for torch.Generator() api. + if torch_device == "mps": + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) + + out_image = pipeline( + generator=generator, + num_inference_steps=2, + output_type="np", + ).images + sample_size = (sample_size, sample_size) if isinstance(sample_size, int) else sample_size + assert out_image.shape == (1, *sample_size, 3) + + def test_stable_diffusion_components(self): """Test that components property works correctly""" - unet = self.dummy_cond_unet + unet = self.dummy_cond_unet() scheduler = PNDMScheduler(skip_prk_steps=True) vae = self.dummy_vae bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0] init_image = Image.fromarray(np.uint8(image)).convert("RGB") - mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) + mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32)) # make sure here that pndm scheduler skips prk inpaint = StableDiffusionInpaintPipelineLegacy( @@ -396,7 +430,187 @@ class PipelineFastTests(unittest.TestCase): assert image_inpaint.shape == (1, 32, 32, 3) assert image_img2img.shape == (1, 32, 32, 3) - assert image_text2img.shape == (1, 128, 128, 3) + assert image_text2img.shape == (1, 64, 64, 3) + + def test_set_scheduler(self): + unet = self.dummy_cond_unet() + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + sd = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + + sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, DDIMScheduler) + sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, DDPMScheduler) + sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, PNDMScheduler) + sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, LMSDiscreteScheduler) + sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, EulerDiscreteScheduler) + sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler) + sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, DPMSolverMultistepScheduler) + + def test_set_scheduler_consistency(self): + unet = self.dummy_cond_unet() + pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") + ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + sd = StableDiffusionPipeline( + unet=unet, + scheduler=pndm, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + + pndm_config = sd.scheduler.config + sd.scheduler = DDPMScheduler.from_config(pndm_config) + sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config) + pndm_config_2 = sd.scheduler.config + pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config} + + assert dict(pndm_config) == dict(pndm_config_2) + + sd = StableDiffusionPipeline( + unet=unet, + scheduler=ddim, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + + ddim_config = sd.scheduler.config + sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config) + sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config) + ddim_config_2 = sd.scheduler.config + ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config} + + assert dict(ddim_config) == dict(ddim_config_2) + + def test_optional_components(self): + unet = self.dummy_cond_unet() + pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + orig_sd = StableDiffusionPipeline( + unet=unet, + scheduler=pndm, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=unet, + feature_extractor=self.dummy_extractor, + ) + sd = orig_sd + + assert sd.config.requires_safety_checker is True + + with tempfile.TemporaryDirectory() as tmpdirname: + sd.save_pretrained(tmpdirname) + + # Test that passing None works + sd = StableDiffusionPipeline.from_pretrained( + tmpdirname, feature_extractor=None, safety_checker=None, requires_safety_checker=False + ) + + assert sd.config.requires_safety_checker is False + assert sd.config.safety_checker == (None, None) + assert sd.config.feature_extractor == (None, None) + + with tempfile.TemporaryDirectory() as tmpdirname: + sd.save_pretrained(tmpdirname) + + # Test that loading previous None works + sd = StableDiffusionPipeline.from_pretrained(tmpdirname) + + assert sd.config.requires_safety_checker is False + assert sd.config.safety_checker == (None, None) + assert sd.config.feature_extractor == (None, None) + + orig_sd.save_pretrained(tmpdirname) + + # Test that loading without any directory works + shutil.rmtree(os.path.join(tmpdirname, "safety_checker")) + with open(os.path.join(tmpdirname, sd.config_name)) as f: + config = json.load(f) + config["safety_checker"] = [None, None] + with open(os.path.join(tmpdirname, sd.config_name), "w") as f: + json.dump(config, f) + + sd = StableDiffusionPipeline.from_pretrained(tmpdirname, requires_safety_checker=False) + sd.save_pretrained(tmpdirname) + sd = StableDiffusionPipeline.from_pretrained(tmpdirname) + + assert sd.config.requires_safety_checker is False + assert sd.config.safety_checker == (None, None) + assert sd.config.feature_extractor == (None, None) + + # Test that loading from deleted model index works + with open(os.path.join(tmpdirname, sd.config_name)) as f: + config = json.load(f) + del config["safety_checker"] + del config["feature_extractor"] + with open(os.path.join(tmpdirname, sd.config_name), "w") as f: + json.dump(config, f) + + sd = StableDiffusionPipeline.from_pretrained(tmpdirname) + + assert sd.config.requires_safety_checker is False + assert sd.config.safety_checker == (None, None) + assert sd.config.feature_extractor == (None, None) + + with tempfile.TemporaryDirectory() as tmpdirname: + sd.save_pretrained(tmpdirname) + + # Test that partially loading works + sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor) + + assert sd.config.requires_safety_checker is False + assert sd.config.safety_checker == (None, None) + assert sd.config.feature_extractor != (None, None) + + # Test that partially loading works + sd = StableDiffusionPipeline.from_pretrained( + tmpdirname, + feature_extractor=self.dummy_extractor, + safety_checker=unet, + requires_safety_checker=[True, True], + ) + + assert sd.config.requires_safety_checker == [True, True] + assert sd.config.safety_checker != (None, None) + assert sd.config.feature_extractor != (None, None) + + with tempfile.TemporaryDirectory() as tmpdirname: + sd.save_pretrained(tmpdirname) + sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor) + + assert sd.config.requires_safety_checker == [True, True] + assert sd.config.safety_checker != (None, None) + assert sd.config.feature_extractor != (None, None) @slow @@ -440,7 +654,10 @@ class PipelineSlowTests(unittest.TestCase): force_download=True, ) - assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n" + assert ( + cap_logger.out + == "Keyword arguments {'not_used': True} are not expected by DDPMPipeline and will be ignored.\n" + ) def test_from_pretrained_save_pretrained(self): # 1. Load models @@ -519,7 +736,7 @@ class PipelineSlowTests(unittest.TestCase): def test_output_format(self): model_path = "google/ddpm-cifar10-32" - scheduler = DDIMScheduler.from_config(model_path) + scheduler = DDIMScheduler.from_pretrained(model_path) pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index 72316aad92..9b9dcddd60 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -78,7 +78,7 @@ class FlaxPipelineTests(unittest.TestCase): images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images - assert images.shape == (num_samples, 1, 128, 128, 3) + assert images.shape == (num_samples, 1, 64, 64, 3) if jax.device_count() == 8: assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 3.1111548) < 1e-3 assert np.abs(np.abs(images, dtype=np.float32).sum() - 199746.95) < 5e-1 diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index a9770f0a54..f840f8ce97 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import json +import os import tempfile import unittest from typing import Dict, List, Tuple @@ -21,24 +23,193 @@ import numpy as np import torch import torch.nn.functional as F +import diffusers from diffusers import ( DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + HeunDiscreteScheduler, IPNDMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler, VQDiffusionScheduler, + logging, ) +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import deprecate, torch_device +from diffusers.utils.testing_utils import CaptureLogger torch.backends.cuda.matmul.allow_tf32 = False +class SchedulerObject(SchedulerMixin, ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 3], + ): + pass + + +class SchedulerObject2(SchedulerMixin, ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + f=[1, 3], + ): + pass + + +class SchedulerObject3(SchedulerMixin, ConfigMixin): + config_name = "config.json" + + @register_to_config + def __init__( + self, + a=2, + b=5, + c=(2, 5), + d="for diffusion", + e=[1, 3], + f=[1, 3], + ): + pass + + +class SchedulerBaseTests(unittest.TestCase): + def test_save_load_from_different_config(self): + obj = SchedulerObject() + + # mock add obj class to `diffusers` + setattr(diffusers, "SchedulerObject", SchedulerObject) + logger = logging.get_logger("diffusers.configuration_utils") + + with tempfile.TemporaryDirectory() as tmpdirname: + obj.save_config(tmpdirname) + with CaptureLogger(logger) as cap_logger_1: + config = SchedulerObject2.load_config(tmpdirname) + new_obj_1 = SchedulerObject2.from_config(config) + + # now save a config parameter that is not expected + with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f: + data = json.load(f) + data["unexpected"] = True + + with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f: + json.dump(data, f) + + with CaptureLogger(logger) as cap_logger_2: + config = SchedulerObject.load_config(tmpdirname) + new_obj_2 = SchedulerObject.from_config(config) + + with CaptureLogger(logger) as cap_logger_3: + config = SchedulerObject2.load_config(tmpdirname) + new_obj_3 = SchedulerObject2.from_config(config) + + assert new_obj_1.__class__ == SchedulerObject2 + assert new_obj_2.__class__ == SchedulerObject + assert new_obj_3.__class__ == SchedulerObject2 + + assert cap_logger_1.out == "" + assert ( + cap_logger_2.out + == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and" + " will" + " be ignored. Please verify your config.json configuration file.\n" + ) + assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out + + def test_save_load_compatible_schedulers(self): + SchedulerObject2._compatibles = ["SchedulerObject"] + SchedulerObject._compatibles = ["SchedulerObject2"] + + obj = SchedulerObject() + + # mock add obj class to `diffusers` + setattr(diffusers, "SchedulerObject", SchedulerObject) + setattr(diffusers, "SchedulerObject2", SchedulerObject2) + logger = logging.get_logger("diffusers.configuration_utils") + + with tempfile.TemporaryDirectory() as tmpdirname: + obj.save_config(tmpdirname) + + # now save a config parameter that is expected by another class, but not origin class + with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f: + data = json.load(f) + data["f"] = [0, 0] + data["unexpected"] = True + + with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f: + json.dump(data, f) + + with CaptureLogger(logger) as cap_logger: + config = SchedulerObject.load_config(tmpdirname) + new_obj = SchedulerObject.from_config(config) + + assert new_obj.__class__ == SchedulerObject + + assert ( + cap_logger.out + == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and" + " will" + " be ignored. Please verify your config.json configuration file.\n" + ) + + def test_save_load_from_different_config_comp_schedulers(self): + SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"] + SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"] + SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"] + + obj = SchedulerObject() + + # mock add obj class to `diffusers` + setattr(diffusers, "SchedulerObject", SchedulerObject) + setattr(diffusers, "SchedulerObject2", SchedulerObject2) + setattr(diffusers, "SchedulerObject3", SchedulerObject3) + logger = logging.get_logger("diffusers.configuration_utils") + logger.setLevel(diffusers.logging.INFO) + + with tempfile.TemporaryDirectory() as tmpdirname: + obj.save_config(tmpdirname) + + with CaptureLogger(logger) as cap_logger_1: + config = SchedulerObject.load_config(tmpdirname) + new_obj_1 = SchedulerObject.from_config(config) + + with CaptureLogger(logger) as cap_logger_2: + config = SchedulerObject2.load_config(tmpdirname) + new_obj_2 = SchedulerObject2.from_config(config) + + with CaptureLogger(logger) as cap_logger_3: + config = SchedulerObject3.load_config(tmpdirname) + new_obj_3 = SchedulerObject3.from_config(config) + + assert new_obj_1.__class__ == SchedulerObject + assert new_obj_2.__class__ == SchedulerObject2 + assert new_obj_3.__class__ == SchedulerObject3 + + assert cap_logger_1.out == "" + assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n" + assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n" + + class SchedulerCommonTest(unittest.TestCase): scheduler_classes = () forward_default_kwargs = () @@ -102,7 +273,7 @@ class SchedulerCommonTest(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -145,7 +316,7 @@ class SchedulerCommonTest(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -187,7 +358,7 @@ class SchedulerCommonTest(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -205,6 +376,42 @@ class SchedulerCommonTest(unittest.TestCase): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + def test_compatibles(self): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + + scheduler = scheduler_class(**scheduler_config) + + assert all(c is not None for c in scheduler.compatibles) + + for comp_scheduler_cls in scheduler.compatibles: + comp_scheduler = comp_scheduler_cls.from_config(scheduler.config) + assert comp_scheduler is not None + + new_scheduler = scheduler_class.from_config(comp_scheduler.config) + + new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config} + scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config} + + # make sure that configs are essentially identical + assert new_scheduler_config == dict(scheduler.config) + + # make sure that only differences are for configs that are not in init + init_keys = inspect.signature(scheduler_class.__init__).parameters.keys() + assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set() + + def test_from_pretrained(self): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + + scheduler = scheduler_class(**scheduler_config) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_pretrained(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) + + assert scheduler.config == new_scheduler.config + def test_step_shape(self): kwargs = dict(self.forward_default_kwargs) @@ -356,6 +563,27 @@ class SchedulerCommonTest(unittest.TestCase): noised = scheduler.add_noise(scaled_sample, noise, t) self.assertEqual(noised.shape, scaled_sample.shape) + def test_deprecated_kwargs(self): + for scheduler_class in self.scheduler_classes: + has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters + has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" + " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`" + f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" + " deprecated argument from `_deprecated_kwargs = []`" + ) + class DDPMSchedulerTest(SchedulerCommonTest): scheduler_classes = (DDPMScheduler,) @@ -393,7 +621,12 @@ class DDPMSchedulerTest(SchedulerCommonTest): for clip_sample in [True, False]: self.check_over_configs(clip_sample=clip_sample) - def test_predict_epsilon(self): + def test_prediction_type(self): + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_deprecated_predict_epsilon(self): + deprecate("remove this test", "0.10.0", "remove") for predict_epsilon in [True, False]: self.check_over_configs(predict_epsilon=predict_epsilon) @@ -589,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): "beta_end": 0.02, "beta_schedule": "linear", "solver_order": 2, - "predict_epsilon": True, + "prediction_type": "epsilon", "thresholding": False, "sample_max_value": 1.0, "algorithm_type": "dpmsolver++", @@ -616,7 +849,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) new_scheduler.set_timesteps(num_inference_steps) # copy over dummy past residuals new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order] @@ -648,7 +881,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) # copy over dummy past residuals new_scheduler.set_timesteps(num_inference_steps) @@ -715,10 +948,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): for order in [1, 2, 3]: for solver_type in ["midpoint", "heun"]: for threshold in [0.5, 1.0, 2.0]: - for predict_epsilon in [True, False]: + for prediction_type in ["epsilon", "sample"]: self.check_over_configs( thresholding=True, - predict_epsilon=predict_epsilon, + prediction_type=prediction_type, sample_max_value=threshold, algorithm_type="dpmsolver++", solver_order=order, @@ -729,17 +962,17 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): for algorithm_type in ["dpmsolver", "dpmsolver++"]: for solver_type in ["midpoint", "heun"]: for order in [1, 2, 3]: - for predict_epsilon in [True, False]: + for prediction_type in ["epsilon", "sample"]: self.check_over_configs( solver_order=order, solver_type=solver_type, - predict_epsilon=predict_epsilon, + prediction_type=prediction_type, algorithm_type=algorithm_type, ) sample = self.full_loop( solver_order=order, solver_type=solver_type, - predict_epsilon=predict_epsilon, + prediction_type=prediction_type, algorithm_type=algorithm_type, ) assert not torch.isnan(sample).any(), "Samples have nan numbers" @@ -758,6 +991,22 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): assert abs(result_mean.item() - 0.3301) < 1e-3 + def test_fp16_support(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter.half() + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + assert sample.dtype == torch.float16 + class PNDMSchedulerTest(SchedulerCommonTest): scheduler_classes = (PNDMScheduler,) @@ -790,7 +1039,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) new_scheduler.set_timesteps(num_inference_steps) # copy over dummy past residuals new_scheduler.ets = dummy_past_residuals[:] @@ -825,7 +1074,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) # copy over dummy past residuals new_scheduler.set_timesteps(num_inference_steps) @@ -1043,7 +1292,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) output = scheduler.step_pred( residual, time_step, sample, generator=torch.manual_seed(0), **kwargs @@ -1074,7 +1323,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) output = scheduler.step_pred( residual, time_step, sample, generator=torch.manual_seed(0), **kwargs @@ -1470,7 +1719,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) new_scheduler.set_timesteps(num_inference_steps) # copy over dummy past residuals new_scheduler.ets = dummy_past_residuals[:] @@ -1508,7 +1757,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler = scheduler_class.from_config(tmpdirname) + new_scheduler = scheduler_class.from_pretrained(tmpdirname) # copy over dummy past residuals new_scheduler.set_timesteps(num_inference_steps) @@ -1644,3 +1893,95 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest): def test_add_noise_device(self): pass + + +class HeunDiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (HeunDiscreteScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "trained_betas": None, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["cpu", "mps"]: + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 + else: + # CUDA + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 + + def test_full_loop_device(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + model = self.dummy_model() + sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if str(torch_device).startswith("cpu"): + # The following sum varies between 148 and 156 on mps. Why? + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 + elif str(torch_device).startswith("mps"): + # Larger tolerance on mps + assert abs(result_mean.item() - 0.0002) < 1e-2 + else: + # CUDA + assert abs(result_sum.item() - 0.1233) < 1e-2 + assert abs(result_mean.item() - 0.0002) < 1e-3 diff --git a/tests/test_scheduler_flax.py b/tests/test_scheduler_flax.py index 7928939f2d..5ada689b72 100644 --- a/tests/test_scheduler_flax.py +++ b/tests/test_scheduler_flax.py @@ -12,12 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import tempfile import unittest from typing import Dict, List, Tuple from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler -from diffusers.utils import is_flax_available +from diffusers.utils import deprecate, is_flax_available from diffusers.utils.testing_utils import require_flax @@ -83,7 +84,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): state = scheduler.set_timesteps(state, num_inference_steps) @@ -112,7 +113,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): state = scheduler.set_timesteps(state, num_inference_steps) @@ -140,7 +141,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): state = scheduler.set_timesteps(state, num_inference_steps) @@ -228,6 +229,27 @@ class FlaxSchedulerCommonTest(unittest.TestCase): recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + def test_deprecated_kwargs(self): + for scheduler_class in self.scheduler_classes: + has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters + has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" + " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`" + f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" + " deprecated argument from `_deprecated_kwargs = []`" + ) + @require_flax class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest): @@ -373,7 +395,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): state = scheduler.set_timesteps(state, num_inference_steps) @@ -401,7 +423,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): state = scheduler.set_timesteps(state, num_inference_steps) @@ -430,7 +452,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): state = scheduler.set_timesteps(state, num_inference_steps) @@ -599,6 +621,26 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): assert abs(result_sum - 149.0784) < 1e-2 assert abs(result_mean - 0.1941) < 1e-3 + def test_prediction_type(self): + for prediction_type in ["epsilon", "sample", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_deprecated_predict_epsilon(self): + deprecate("remove this test", "0.10.0", "remove") + for predict_epsilon in [True, False]: + self.check_over_configs(predict_epsilon=predict_epsilon) + + def test_deprecated_predict_epsilon_to_prediction_type(self): + deprecate("remove this test", "0.10.0", "remove") + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(predict_epsilon=True) + scheduler = scheduler_class.from_config(scheduler_config) + assert scheduler.prediction_type == "epsilon" + + scheduler_config = self.get_scheduler_config(predict_epsilon=False) + scheduler = scheduler_class.from_config(scheduler_config) + assert scheduler.prediction_type == "sample" + @require_flax class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): @@ -633,7 +675,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) # copy over dummy past residuals new_state = new_state.replace(ets=dummy_past_residuals[:]) @@ -720,7 +762,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) # copy over dummy past residuals new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) diff --git a/tests/test_utils.py b/tests/test_utils.py index 35cf574210..761242eb9a 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -26,7 +26,7 @@ class DeprecateTester(unittest.TestCase): def test_deprecate_function_arg(self): kwargs = {"deprecated_arg": 4} - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: output = deprecate("deprecated_arg", self.higher_version, "message", take_from=kwargs) assert output == 4 @@ -39,7 +39,7 @@ class DeprecateTester(unittest.TestCase): def test_deprecate_function_arg_tuple(self): kwargs = {"deprecated_arg": 4} - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: output = deprecate(("deprecated_arg", self.higher_version, "message"), take_from=kwargs) assert output == 4 @@ -51,7 +51,7 @@ class DeprecateTester(unittest.TestCase): def test_deprecate_function_args(self): kwargs = {"deprecated_arg_1": 4, "deprecated_arg_2": 8} - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: output_1, output_2 = deprecate( ("deprecated_arg_1", self.higher_version, "Hey"), ("deprecated_arg_2", self.higher_version, "Hey"), @@ -81,7 +81,7 @@ class DeprecateTester(unittest.TestCase): assert "got an unexpected keyword argument `deprecated_arg`" in str(error.exception) def test_deprecate_arg_no_kwarg(self): - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: deprecate(("deprecated_arg", self.higher_version, "message")) assert ( @@ -90,7 +90,7 @@ class DeprecateTester(unittest.TestCase): ) def test_deprecate_args_no_kwarg(self): - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: deprecate( ("deprecated_arg_1", self.higher_version, "Hey"), ("deprecated_arg_2", self.higher_version, "Hey"), @@ -108,7 +108,7 @@ class DeprecateTester(unittest.TestCase): class Args: arg = 5 - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: arg = deprecate(("arg", self.higher_version, "message"), take_from=Args()) assert arg == 5 @@ -122,7 +122,7 @@ class DeprecateTester(unittest.TestCase): arg = 5 foo = 7 - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: arg_1, arg_2 = deprecate( ("arg", self.higher_version, "message"), ("foo", self.higher_version, "message"), @@ -158,7 +158,7 @@ class DeprecateTester(unittest.TestCase): ) def test_deprecate_incorrect_no_standard_warn(self): - with self.assertWarns(DeprecationWarning) as warning: + with self.assertWarns(FutureWarning) as warning: deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False) assert str(warning.warning) == "This message is better!!!" diff --git a/utils/check_copies.py b/utils/check_copies.py index 395cefb9c4..16782397da 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -153,6 +153,10 @@ def is_copy_consistent(filename, overwrite=False): observed_code_lines = lines[start_index:line_index] observed_code = "".join(observed_code_lines) + # Remove any nested `Copied from` comments to avoid circular copies + theoretical_code = [line for line in theoretical_code.split("\n") if _re_copy_warning.search(line) is None] + theoretical_code = "\n".join(theoretical_code) + # Before comparing, use the `replace_pattern` on the original code. if len(replace_pattern) > 0: patterns = replace_pattern.replace("with", "").split(",")