diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000000..eaca39797e --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,27 @@ +name: Stale Bot + +on: + schedule: + - cron: "0 15 * * *" + +jobs: + close_stale_issues: + name: Close Stale Issues + if: github.repository == 'huggingface/diffusers' + runs-on: ubuntu-latest + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@v2 + + - name: Setup Python + uses: actions/setup-python@v1 + with: + python-version: 3.7 + + - name: Install requirements + run: | + pip install PyGithub + - name: Close stale issues + run: | + python utils/stale.py diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml new file mode 100644 index 0000000000..fbd051b4da --- /dev/null +++ b/.github/workflows/typos.yml @@ -0,0 +1,14 @@ +name: Check typos + +on: + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: typos-action + uses: crate-ci/typos@v1.12.4 diff --git a/README.md b/README.md index 434d0cee2b..5a25ce5012 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ as a modular toolbox for inference and training of diffusion models. More precisely, 🤗 Diffusers offers: - State-of-the-art diffusion pipelines that can be run in inference with just a couple of lines of code (see [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)). Check [this overview](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/README.md#pipelines-summary) to see all supported pipelines and their corresponding official papers. -- Various noise schedulers that can be used interchangeably for the prefered speed vs. quality trade-off in inference (see [src/diffusers/schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)). +- Various noise schedulers that can be used interchangeably for the preferred speed vs. quality trade-off in inference (see [src/diffusers/schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)). - Multiple types of models, such as UNet, can be used as building blocks in an end-to-end diffusion system (see [src/diffusers/models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)). - Training examples to show how to train the most popular diffusion model tasks (see [examples](https://github.com/huggingface/diffusers/tree/main/examples), *e.g.* [unconditional-image-generation](https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation)). @@ -30,7 +30,7 @@ More precisely, 🤗 Diffusers offers: **With `pip`** ```bash -pip install --upgrade diffusers # should install diffusers 0.2.4 +pip install --upgrade diffusers ``` **With `conda`** @@ -39,6 +39,10 @@ pip install --upgrade diffusers # should install diffusers 0.2.4 conda install -c conda-forge diffusers ``` +**Apple Silicon (M1/M2) support** + +Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps). + ## Contributing We ❤️ contributions from the open-source community! @@ -191,7 +195,7 @@ with autocast("cuda"): images[0].save("fantasy_landscape.png") ``` -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ### In-painting using Stable Diffusion @@ -254,42 +258,49 @@ If you want to run the code yourself 💻, you can try out: - [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256) ```python # !pip install diffusers transformers +from torch import autocast from diffusers import DiffusionPipeline +device = "cuda" model_id = "CompVis/ldm-text2im-large-256" # load model and scheduler ldm = DiffusionPipeline.from_pretrained(model_id) +ldm = ldm.to(device) # run pipeline in inference (sample random noise and denoise) prompt = "A painting of a squirrel eating a burger" -images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images +with autocast(device): + image = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images[0] -# save images -for idx, image in enumerate(images): - image.save(f"squirrel-{idx}.png") +# save image +image.save("squirrel.png") ``` - [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256) ```python # !pip install diffusers +from torch import autocast from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline model_id = "google/ddpm-celebahq-256" +device = "cuda" # load model and scheduler ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference +ddpm.to(device) # run pipeline in inference (sample random noise and denoise) -image = ddpm().images +with autocast("cuda"): + image = ddpm().images[0] # save image -image[0].save("ddpm_generated_image.png") +image.save("ddpm_generated_image.png") ``` - [Unconditional Latent Diffusion](https://huggingface.co/CompVis/ldm-celebahq-256) -- [Unconditional Diffusion with continous scheduler](https://huggingface.co/google/ncsnpp-ffhq-1024) +- [Unconditional Diffusion with continuous scheduler](https://huggingface.co/google/ncsnpp-ffhq-1024) **Other Notebooks**: -* [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg), +* [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg), * [tweak images via repeated Stable Diffusion seeds](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) ![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg), ### Web Demos @@ -335,8 +346,8 @@ The class provides functionality to compute previous image according to alpha, b ## Philosophy -- Readability and clarity is prefered over highly optimized code. A strong importance is put on providing readable, intuitive and elementary code design. *E.g.*, the provided [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) are separated from the provided [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and provide well-commented code that can be read alongside the original paper. -- Diffusers is **modality independent** and focuses on providing pretrained models and tools to build systems that generate **continous outputs**, *e.g.* vision and audio. +- Readability and clarity is preferred over highly optimized code. A strong importance is put on providing readable, intuitive and elementary code design. *E.g.*, the provided [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers) are separated from the provided [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and provide well-commented code that can be read alongside the original paper. +- Diffusers is **modality independent** and focuses on providing pretrained models and tools to build systems that generate **continuous outputs**, *e.g.* vision and audio. - Diffusion models and schedulers are provided as concise, elementary building blocks. In contrast, diffusion pipelines are a collection of end-to-end diffusion systems that can be used out-of-the-box, should stay as close as possible to their original implementation and can include components of another library, such as text-encoders. Examples for diffusion pipelines are [Glide](https://github.com/openai/glide-text2im) and [Latent Diffusion](https://github.com/CompVis/latent-diffusion). ## In the works diff --git a/_typos.toml b/_typos.toml new file mode 100644 index 0000000000..4025388915 --- /dev/null +++ b/_typos.toml @@ -0,0 +1,12 @@ +# Files for typos +# Instruction: https://github.com/marketplace/actions/typos-action#getting-started + +[default.extend-identifiers] + +[default.extend-words] +NIN_="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py +nd="np" # nd may be np (numpy) + + +[files] +extend-exclude = ["_typos.toml"] diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 612c449c01..3d1bd4929d 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -35,8 +35,6 @@ title: "Open Vino" - local: optimization/mps title: "MPS" - - local: optimization/other - title: "Other" title: "Optimization/Special Hardware" - sections: - local: training/overview diff --git a/docs/source/api/configuration.mdx b/docs/source/api/configuration.mdx index 5c435dc8e1..45176f55b0 100644 --- a/docs/source/api/configuration.mdx +++ b/docs/source/api/configuration.mdx @@ -10,19 +10,14 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Models +# Configuration -Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models. -The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$. -The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub. +In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are +passed to the respective `__init__` methods in a JSON-configuration file. -## API +TODO(PVP) - add example and better info here -Models should provide the `def forward` function and initialization of the model. -All saving, loading, and utilities should be in the base ['ModelMixin'] class. - -## Examples - -- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3. -- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991). -- TODO: mention VAE / SDE score estimation \ No newline at end of file +## ConfigMixin +[[autodoc]] ConfigMixin + - from_config + - save_config diff --git a/docs/source/api/diffusion_pipeline.mdx b/docs/source/api/diffusion_pipeline.mdx index 5c435dc8e1..6a0f758f76 100644 --- a/docs/source/api/diffusion_pipeline.mdx +++ b/docs/source/api/diffusion_pipeline.mdx @@ -10,19 +10,30 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Models +# Pipelines -Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models. -The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$. -The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub. +The [`DiffusionPipeline`] is the easiest way to load any pretrained diffusion pipeline from the [Hub](https://huggingface.co/models?library=diffusers) and to use it in inference. -## API + + + One should not use the Diffusion Pipeline class for training or fine-tuning a diffusion model. Individual + components of diffusion pipelines are usually trained individually, so we suggest to directly work + with [`UNetModel`] and [`UNetConditionModel`]. -Models should provide the `def forward` function and initialization of the model. -All saving, loading, and utilities should be in the base ['ModelMixin'] class. + -## Examples +Any diffusion pipeline that is loaded with [`~DiffusionPipeline.from_pretrained`] will automatically +detect the pipeline type, *e.g.* [`StableDiffusionPipeline`] and consequently load each component of the +pipeline and pass them into the `__init__` function of the pipeline, *e.g.* [`~StableDiffusionPipeline.__init__`]. -- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3. -- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991). -- TODO: mention VAE / SDE score estimation \ No newline at end of file +Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrained`]. + +## DiffusionPipeline +[[autodoc]] DiffusionPipeline + - from_pretrained + - save_pretrained + +## ImagePipelineOutput +By default diffusion pipelines return an object of class + +[[autodoc]] pipeline_utils.ImagePipelineOutput diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index 5c435dc8e1..525548e7c3 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -16,13 +16,32 @@ Diffusers contains pretrained models for popular algorithms and modules for crea The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$. The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub. -## API +## ModelMixin +[[autodoc]] ModelMixin -Models should provide the `def forward` function and initialization of the model. -All saving, loading, and utilities should be in the base ['ModelMixin'] class. +## UNet2DOutput +[[autodoc]] models.unet_2d.UNet2DOutput -## Examples +## UNet2DModel +[[autodoc]] UNet2DModel -- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3. -- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991). -- TODO: mention VAE / SDE score estimation \ No newline at end of file +## UNet2DConditionOutput +[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput + +## UNet2DConditionModel +[[autodoc]] UNet2DConditionModel + +## DecoderOutput +[[autodoc]] models.vae.DecoderOutput + +## VQEncoderOutput +[[autodoc]] models.vae.VQEncoderOutput + +## VQModel +[[autodoc]] VQModel + +## AutoencoderKLOutput +[[autodoc]] models.vae.AutoencoderKLOutput + +## AutoencoderKL +[[autodoc]] AutoencoderKL diff --git a/docs/source/api/outputs.mdx b/docs/source/api/outputs.mdx index 5c435dc8e1..010761fb2e 100644 --- a/docs/source/api/outputs.mdx +++ b/docs/source/api/outputs.mdx @@ -10,19 +10,46 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Models +# BaseOutputs -Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models. -The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$. -The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub. +All models have outputs that are instances of subclasses of [`~utils.BaseOutput`]. Those are +data structures containing all the information returned by the model, but that can also be used as tuples or +dictionaries. -## API +Let's see how this looks in an example: -Models should provide the `def forward` function and initialization of the model. -All saving, loading, and utilities should be in the base ['ModelMixin'] class. +```python +from diffusers import DDIMPipeline -## Examples +pipeline = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32") +outputs = pipeline() +``` -- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3. -- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991). -- TODO: mention VAE / SDE score estimation \ No newline at end of file +The `outputs` object is a [`~pipeline_utils.ImagePipelineOutput`], as we can see in the +documentation of that class below, it means it has an image attribute. + +You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you will get `None`: + +```python +outputs.images +``` + +or via keyword lookup + +```python +outputs["images"] +``` + +When considering our `outputs` object as tuple, it only considers the attributes that don't have `None` values. +Here for instance, we could retrieve images via indexing: + +```python +outputs[:1] +``` + +which will return the tuple `(outputs.images)` for instance. + +## BaseOutput + +[[autodoc]] utils.BaseOutput + - to_tuple diff --git a/docs/source/api/pipelines/ddim.mdx b/docs/source/api/pipelines/ddim.mdx index 7a28c0ee10..41952c2c0d 100644 --- a/docs/source/api/pipelines/ddim.mdx +++ b/docs/source/api/pipelines/ddim.mdx @@ -17,7 +17,6 @@ The original codebase of this paper can be found [here](https://github.com/ermon | [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim/pipeline_ddim.py) | *Unconditional Image Generation* | - | -## API - -[[autodoc]] pipelines.ddim.pipeline_ddim.DDIMPipeline +## DDIMPipeline +[[autodoc]] DDIMPipeline - __call__ diff --git a/docs/source/api/pipelines/ddpm.mdx b/docs/source/api/pipelines/ddpm.mdx index 88ed3fbf16..b0e08f84ef 100644 --- a/docs/source/api/pipelines/ddpm.mdx +++ b/docs/source/api/pipelines/ddpm.mdx @@ -19,7 +19,6 @@ The original codebase of this paper can be found [here](https://github.com/hojon | [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm/pipeline_ddpm.py) | *Unconditional Image Generation* | - | -## API - -[[autodoc]] pipelines.ddpm.pipeline_ddpm.DDPMPipeline +# DDPMPipeline +[[autodoc]] DDPMPipeline - __call__ diff --git a/docs/source/api/pipelines/latent_diffusion.mdx b/docs/source/api/pipelines/latent_diffusion.mdx index 837e931e06..821ad4fea5 100644 --- a/docs/source/api/pipelines/latent_diffusion.mdx +++ b/docs/source/api/pipelines/latent_diffusion.mdx @@ -25,7 +25,6 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff ## Examples: -## API - +## LDMTextToImagePipeline [[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline - __call__ diff --git a/docs/source/api/pipelines/latent_diffusion_uncond.mdx b/docs/source/api/pipelines/latent_diffusion_uncond.mdx index 5868d07756..4e12fa7f52 100644 --- a/docs/source/api/pipelines/latent_diffusion_uncond.mdx +++ b/docs/source/api/pipelines/latent_diffusion_uncond.mdx @@ -24,7 +24,6 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff ## Examples: -## API - -[[autodoc]] pipelines.latent_diffusion_uncond.pipeline_latent_diffusion_uncond.LDMPipeline +## LDMPipeline +[[autodoc]] LDMPipeline - __call__ diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx index 30519b2d17..881e5cdbd9 100644 --- a/docs/source/api/pipelines/overview.mdx +++ b/docs/source/api/pipelines/overview.mdx @@ -51,8 +51,8 @@ available a colab notebook to directly try them out. | [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) -| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stochatic_karras_ve](./stochatic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. @@ -143,7 +143,7 @@ with autocast("cuda"): images[0].save("fantasy_landscape.png") ``` -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ### Tweak prompts reusing seeds and latents @@ -187,4 +187,4 @@ with autocast("cuda"): images[0].save("cat_on_bench.png") ``` -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb) +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) diff --git a/docs/source/api/pipelines/pndm.mdx b/docs/source/api/pipelines/pndm.mdx index 734996109c..cbb7c5a929 100644 --- a/docs/source/api/pipelines/pndm.mdx +++ b/docs/source/api/pipelines/pndm.mdx @@ -17,8 +17,7 @@ The original codebase can be found [here](https://github.com/luping-liu/PNDM). | [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm/pipeline_pndm.py) | *Unconditional Image Generation* | - | -## API - +## PNDMPipeline [[autodoc]] pipelines.pndm.pipeline_pndm.PNDMPipeline - __call__ diff --git a/docs/source/api/pipelines/score_sde_ve.mdx b/docs/source/api/pipelines/score_sde_ve.mdx index b4bda54ff0..2e555914ac 100644 --- a/docs/source/api/pipelines/score_sde_ve.mdx +++ b/docs/source/api/pipelines/score_sde_ve.mdx @@ -18,8 +18,7 @@ This pipeline implements the Variance Expanding (VE) variant of the method. |---|---|:---:| | [pipeline_score_sde_ve.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py) | *Unconditional Image Generation* | - | -## API - -[[autodoc]] pipelines.score_sde_ve.pipeline_score_sde_ve.ScoreSdeVePipeline +## ScoreSdeVePipeline +[[autodoc]] ScoreSdeVePipeline - __call__ diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx index 0d1c01991d..3288c679de 100644 --- a/docs/source/api/pipelines/stable_diffusion.mdx +++ b/docs/source/api/pipelines/stable_diffusion.mdx @@ -1,40 +1,39 @@ # Stable diffusion pipelines -## Overview - Stable Diffusion is a text-to-image _latent diffusion_ model created by the researchers and engineers from [CompVis](https://github.com/CompVis), [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/). It's trained on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) dataset. This model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts. With its 860M UNet and 123M text encoder, the model is relatively lightweight and can run on consumer GPUs. Latent diffusion is the research on top of which Stable Diffusion was built. It was proposed in [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer. You can learn more details about it in the [specific pipeline for latent diffusion](pipelines/latent_diffusion) that is part of 🤗 Diffusers. For more details about how Stable Diffusion works and how it differs from the base latent diffusion model, please refer to the official [launch announcement post](https://stability.ai/blog/stable-diffusion-announcement) and [this section of our own blog post](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work). -## Tips - +*Tips*: - To tweak your prompts on a specific result you liked, you can generate your own latents, as demonstrated in the following notebook: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) -- TODO: some interesting Tips -## Available pipelines +*Overview*: | Pipeline | Tasks | Colab | Demo - |---|---|:---:|:---:| - | [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) | [🤗 Stable Diffusion](https://huggingface.co/spaces/stabilityai/stable-diffusion) - | [pipeline_stable_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) | [🤗 Diffuse the Rest](https://huggingface.co/spaces/huggingface/diffuse-the-rest) - | [pipeline_stable_diffusion_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | **Experimental** – *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb) | Coming soon +|---|---|:---:|:---:| +| [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) | [🤗 Stable Diffusion](https://huggingface.co/spaces/stabilityai/stable-diffusion) +| [pipeline_stable_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [🤗 Diffuse the Rest](https://huggingface.co/spaces/huggingface/diffuse-the-rest) +| [pipeline_stable_diffusion_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | **Experimental** – *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | Coming soon -## API +## StableDiffusionPipelineOutput +[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput +## StableDiffusionPipeline [[autodoc]] StableDiffusionPipeline - - __call__ - - enable_attention_slicing - - disable_attention_slicing + - __call__ + - enable_attention_slicing + - disable_attention_slicing +## StableDiffusionImg2ImgPipeline [[autodoc]] StableDiffusionImg2ImgPipeline - - __call__ - - enable_attention_slicing - - disable_attention_slicing + - __call__ + - enable_attention_slicing + - disable_attention_slicing +## StableDiffusionInpaintPipeline [[autodoc]] StableDiffusionInpaintPipeline - - __call__ - - enable_attention_slicing - - disable_attention_slicing - + - __call__ + - enable_attention_slicing + - disable_attention_slicing diff --git a/docs/source/api/pipelines/stochastic_karras_ve.mdx b/docs/source/api/pipelines/stochastic_karras_ve.mdx index f926e2871d..d1ef8f7b3f 100644 --- a/docs/source/api/pipelines/stochastic_karras_ve.mdx +++ b/docs/source/api/pipelines/stochastic_karras_ve.mdx @@ -18,7 +18,6 @@ This pipeline implements the Stochastic sampling tailored to the Variance-Expand | [pipeline_stochastic_karras_ve.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py) | *Unconditional Image Generation* | - | -## API - -[[autodoc]] pipelines.stochastic_karras_ve.pipeline_stochastic_karras_ve.KarrasVePipeline +## KarrasVePipeline +[[autodoc]] KarrasVePipeline - __call__ diff --git a/docs/source/api/schedulers.mdx b/docs/source/api/schedulers.mdx index 1deff1a4bb..2b6e58fe12 100644 --- a/docs/source/api/schedulers.mdx +++ b/docs/source/api/schedulers.mdx @@ -14,7 +14,8 @@ specific language governing permissions and limitations under the License. Diffusers contains multiple pre-built schedule functions for the diffusion process. -## What is a schduler? +## What is a scheduler? + The schedule functions, denoted *Schedulers* in the library take in the output of a trained model, a sample which the diffusion process is iterating on, and a timestep to return a denoised sample. - Schedulers define the methodology for iteratively adding noise to an image or for updating a sample based on model outputs. @@ -23,73 +24,77 @@ The schedule functions, denoted *Schedulers* in the library take in the output o - Schedulers are often defined by a *noise schedule* and an *update rule* to solve the differential equation solution. ### Discrete versus continuous schedulers + All schedulers take in a timestep to predict the updated version of the sample being diffused. The timesteps dictate where in the diffusion process the step is, where data is generated by iterating forward in time and inference is executed by propagating backwards through timesteps. -Different algorithms use timesteps that both discrete (accepting `int` inputs), such as the [`DDPMScheduler`] or [`PNDMScheduler`], and continuous (accepting 'float` inputs), such as the score-based schedulers [`ScoreSdeVeScheduler`] or [`ScoreSdeVpScheduler`]. +Different algorithms use timesteps that both discrete (accepting `int` inputs), such as the [`DDPMScheduler`] or [`PNDMScheduler`], and continuous (accepting `float` inputs), such as the score-based schedulers [`ScoreSdeVeScheduler`] or [`ScoreSdeVpScheduler`]. ## Designing Re-usable schedulers + The core design principle between the schedule functions is to be model, system, and framework independent. This allows for rapid experimentation and cleaner abstractions in the code, where the model prediction is separated from the sample update. To this end, the design of schedulers is such that: + - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality. - Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Numpy support currently exists). ## API + The core API for any new scheduler must follow a limited structure. - Schedulers should provide one or more `def step(...)` functions that should be called to update the generated sample iteratively. - Schedulers should provide a `set_timesteps(...)` method that configures the parameters of a schedule function for a specific inference task. -- Schedulers should be framework-agonstic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch +- Schedulers should be framework-agnostic, but provide a simple functionality to convert the scheduler into a specific framework, such as PyTorch with a `set_format(...)` method. -### Core The base class [`SchedulerMixin`] implements low level utilities used by multiple schedulers. -#### SchedulerMixin +### SchedulerMixin [[autodoc]] SchedulerMixin -#### SchedulerOutput -The class [`SchedulerOutput`] contains the ouputs from any schedulers `step(...)` call. +### SchedulerOutput +The class [`SchedulerOutput`] contains the outputs from any schedulers `step(...)` call. + [[autodoc]] schedulers.scheduling_utils.SchedulerOutput -### Existing Schedulers +### Implemented Schedulers #### Denoising diffusion implicit models (DDIM) Original paper can be found here. -[[autodoc]] schedulers.scheduling_ddim.DDIMScheduler +[[autodoc]] DDIMScheduler #### Denoising diffusion probabilistic models (DDPM) Original paper can be found [here](https://arxiv.org/abs/2010.02502). -[[autodoc]] schedulers.scheduling_ddpm.DDPMScheduler +[[autodoc]] DDPMScheduler -#### Varience exploding, stochastic sampling from Karras et. al +#### Variance exploding, stochastic sampling from Karras et. al Original paper can be found [here](https://arxiv.org/abs/2006.11239). -[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeScheduler +[[autodoc]] KarrasVeScheduler #### Linear multistep scheduler for discrete beta schedules Original implementation can be found [here](https://arxiv.org/abs/2206.00364). -[[autodoc]] schedulers.scheduling_lms_discrete.LMSDiscreteScheduler +[[autodoc]] LMSDiscreteScheduler #### Pseudo numerical methods for diffusion models (PNDM) Original implementation can be found [here](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181). -[[autodoc]] schedulers.scheduling_pndm.PNDMScheduler +[[autodoc]] PNDMScheduler #### variance exploding stochastic differential equation (SDE) scheduler Original paper can be found [here](https://arxiv.org/abs/2011.13456). -[[autodoc]] schedulers.scheduling_sde_ve.ScoreSdeVeScheduler +[[autodoc]] ScoreSdeVeScheduler #### variance preserving stochastic differential equation (SDE) scheduler diff --git a/docs/source/conceptual/stable_diffusion.mdx b/docs/source/conceptual/stable_diffusion.mdx index 044f3937b9..c00359d2ac 100644 --- a/docs/source/conceptual/stable_diffusion.mdx +++ b/docs/source/conceptual/stable_diffusion.mdx @@ -10,23 +10,8 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +# Stable Diffusion +Under construction 🚧 -# Quicktour - -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! - -``` -pip install diffusers -``` - -## Main classes - -### Models - -### Schedulers - -### Pipeliens - - +For now please visit this [very in-detail blog post](https://huggingface.co/blog/stable_diffusion) diff --git a/docs/source/index.mdx b/docs/source/index.mdx index c134925cc6..453b4a5b78 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -42,8 +42,8 @@ available a colab notebook to directly try them out. | [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) -| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stochatic_karras_ve](./api/pipelines/stochatic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 532796ffce..fa9466b6d1 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License. Install Diffusers for with PyTorch. Support for other libraries will come in the future -🤗 Diffusers is tested on Python 3.6+, and PyTorch 1.4.0+. +🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+. ## Install with pip @@ -66,7 +66,7 @@ Clone the repository and install 🤗 Diffusers with the following commands: ```bash git clone https://github.com/huggingface/diffusers.git -cd transformers +cd diffusers pip install -e . ``` diff --git a/docs/source/optimization/onnx.mdx b/docs/source/optimization/onnx.mdx index 044f3937b9..95fd59c86d 100644 --- a/docs/source/optimization/onnx.mdx +++ b/docs/source/optimization/onnx.mdx @@ -11,22 +11,33 @@ specific language governing permissions and limitations under the License. --> +# How to use the ONNX Runtime for inference -# Quicktour +🤗 Diffusers provides a Stable Diffusion pipeline compatible with the ONNX Runtime. This allows you to run Stable Diffusion on any hardware that supports ONNX (including CPUs), and where an accelerated version of PyTorch is not available. -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! +## Installation -``` -pip install diffusers +- TODO + +## Stable Diffusion Inference + +The snippet below demonstrates how to use the ONNX runtime. You need to use `StableDiffusionOnnxPipeline` instead of `StableDiffusionPipeline`. You also need to download the weights from the `onnx` branch of the repository, and indicate the runtime provider you want to use. + +```python +# make sure you're logged in with `huggingface-cli login` +from diffusers import StableDiffusionOnnxPipeline + +pipe = StableDiffusionOnnxPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + revision="onnx", + provider="CUDAExecutionProvider", + use_auth_token=True, +) + +prompt = "a photo of an astronaut riding a horse on mars" +image = pipe(prompt).images[0] ``` -## Main classes - -### Models - -### Schedulers - -### Pipeliens - +## Known Issues +- Generating multiple prompts in a batch seems to take too much memory. While we look into it, you may need to iterate instead of batching. diff --git a/docs/source/optimization/open_vino.mdx b/docs/source/optimization/open_vino.mdx index 044f3937b9..da6878c124 100644 --- a/docs/source/optimization/open_vino.mdx +++ b/docs/source/optimization/open_vino.mdx @@ -10,23 +10,6 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +# OpenVINO - -# Quicktour - -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! - -``` -pip install diffusers -``` - -## Main classes - -### Models - -### Schedulers - -### Pipeliens - - +Under construction 🚧 diff --git a/docs/source/optimization/other.mdx b/docs/source/optimization/other.mdx deleted file mode 100644 index 044f3937b9..0000000000 --- a/docs/source/optimization/other.mdx +++ /dev/null @@ -1,32 +0,0 @@ - - - - -# Quicktour - -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! - -``` -pip install diffusers -``` - -## Main classes - -### Models - -### Schedulers - -### Pipeliens - - diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx index 82caa170e1..9574ecac4a 100644 --- a/docs/source/quicktour.mdx +++ b/docs/source/quicktour.mdx @@ -86,11 +86,11 @@ just like we did before only that now you need to pass your `AUTH_TOKEN`: >>> generator = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=AUTH_TOKEN) ``` -If you do not pass your authentification token you will see that the diffusion system will not be correctly -downloaded. Forcing the user to pass an authentification token ensures that it can be verified that the +If you do not pass your authentication token you will see that the diffusion system will not be correctly +downloaded. Forcing the user to pass an authentication token ensures that it can be verified that the user has indeed read and accepted the license, which also means that an internet connection is required. -**Note**: If you do not want to be forced to pass an authentification token, you can also simply download +**Note**: If you do not want to be forced to pass an authentication token, you can also simply download the weights locally via: ``` @@ -98,7 +98,7 @@ git lfs install git clone https://huggingface.co/CompVis/stable-diffusion-v1-4 ``` -and then load locally saved weights into the pipeline. This way, you do not need to pass an authentification +and then load locally saved weights into the pipeline. This way, you do not need to pass an authentication token. Assuming that `"./stable-diffusion-v1-4"` is the local path to the cloned stable-diffusion-v1-4 repo, you can also load the pipeline as follows: @@ -114,7 +114,7 @@ Running the pipeline is then identical to the code above as it's the same model >>> image.save("image_of_squirrel_painting.png") ``` -Diffusion systems can be used with multiple different [schedulers]("api/schedulers") each with their +Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler, you could use it as follows: @@ -131,15 +131,15 @@ you could use it as follows: [Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model and can do much more than just generating images from text. We have dedicated a whole documentation page, -just for Stable Diffusion [here]("./conceptual/stable_diffusion"). +just for Stable Diffusion [here](./conceptual/stable_diffusion). If you want to know how to optimize Stable Diffusion to run on less memory, higher inference speeds, on specific hardware, such as Mac, or with [ONNX Runtime](https://onnxruntime.ai/), please have a look at our optimization pages: -- [Optimized PyTorch on GPU]("./optimization/fp16") -- [Mac OS with PyTorch]("./optimization/mps") -- [ONNX]("./optimization/onnx) -- [Other clever optimization tricks]("./optimization/other) +- [Optimized PyTorch on GPU](./optimization/fp16) +- [Mac OS with PyTorch](./optimization/mps) +- [ONNX](./optimization/onnx) +- [OpenVINO](./optimization/open_vino) If you want to fine-tune or train your diffusion model, please have a look at the [**training section**](./training/overview) diff --git a/docs/source/training/overview.mdx b/docs/source/training/overview.mdx index 626059d86d..c403298493 100644 --- a/docs/source/training/overview.mdx +++ b/docs/source/training/overview.mdx @@ -10,10 +10,60 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Training +# 🧨 Diffusers Training Examples -You can train diffusion models in multiple ways: +Diffusers examples are a collection of scripts to demonstrate how to effectively use the `diffusers` library +for a variety of use cases. + +**Note**: If you are looking for **official** examples on how to use `diffusers` for inference, +please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines) + +Our examples aspire to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**. +More specifically, this means: + +- **Self-contained**: An example script shall only depend on "pip-install-able" Python packages that can be found in a `requirements.txt` file. Example scripts shall **not** depend on any local files. This means that one can simply download an example script, *e.g.* [train_unconditional.py](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py), install the required dependencies, *e.g.* [requirements.txt](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/requirements.txt) and execute the example script. +- **Easy-to-tweak**: While we strive to present as many use cases as possible, the example scripts are just that - examples. It is expected that they won't work out-of-the box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data and the training loop to allow you to tweak and edit them as required. +- **Beginner-friendly**: We do not aim for providing state-of-the-art training scripts for the newest models, but rather examples that can be used as a way to better understand diffusion models and how to use them with the `diffusers` library. We often purposefully leave out certain state-of-the-art methods if we consider them too complex for beginners. +- **One-purpose-only**: Examples should show one task and one task only. Even if a task is from a modeling +point of view very similar, *e.g.* image super-resolution and image modification tend to use the same model and training method, we want examples to showcase only one task to keep them as readable and easy-to-understand as possible. + +We provide **official** examples that cover the most popular tasks of diffusion models. +*Official* examples are **actively** maintained by the `diffusers` maintainers and we try to rigorously follow our example philosophy as defined above. +If you feel like another important example should exist, we are more than happy to welcome a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) or directly a [Pull Request](https://github.com/huggingface/diffusers/compare) from you! + +Training examples show how to pretrain or fine-tune diffusion models for a variety of tasks. Currently we support: - [Unconditional Training](./unconditional_training) - [Text-to-Image Training](./text2image) - [Text Inversion](./text_inversion) + + +| Task | 🤗 Accelerate | 🤗 Datasets | Colab +|---|---|:---:|:---:| +| [**Unconditional Image Generation**](./unconditional_training) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [**Text-to-Image**](./text2image) | - | - | +| [**Text-Inversion**](./text_inversion) | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) + +## Community + +In addition, we provide **community** examples, which are examples added and maintained by our community. +Community examples can consist of both *training* examples or *inference* pipelines. +For such examples, we are more lenient regarding the philosophy defined above and also cannot guarantee to provide maintenance for every issue. +Examples that are useful for the community, but are either not yet deemed popular or not yet following our above philosophy should go into the [community examples](https://github.com/huggingface/diffusers/tree/main/examples/community) folder. The community folder therefore includes training examples and inference pipelines. +**Note**: Community examples can be a [great first contribution](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) to show to the community how you like to use `diffusers` 🪄. + +## Important note + +To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install . +``` + +Then cd in the example folder of your choice and run + +```bash +pip install -r requirements.txt +``` diff --git a/docs/source/training/text2image.mdx b/docs/source/training/text2image.mdx index 044f3937b9..dcbdece429 100644 --- a/docs/source/training/text2image.mdx +++ b/docs/source/training/text2image.mdx @@ -11,22 +11,6 @@ specific language governing permissions and limitations under the License. --> +# Text-to-Image Training -# Quicktour - -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! - -``` -pip install diffusers -``` - -## Main classes - -### Models - -### Schedulers - -### Pipeliens - - +Under construction 🚧 diff --git a/docs/source/training/text_inversion.mdx b/docs/source/training/text_inversion.mdx index 107cd706f4..8c53421e21 100644 --- a/docs/source/training/text_inversion.mdx +++ b/docs/source/training/text_inversion.mdx @@ -49,7 +49,7 @@ The `textual_inversion.py` script [here](https://github.com/huggingface/diffuser ### Installing the dependencies -Before running the scipts, make sure to install the library's training dependencies: +Before running the scripts, make sure to install the library's training dependencies: ```bash pip install diffusers[training] accelerate transformers @@ -68,7 +68,7 @@ You need to accept the model license before downloading or using the weights. In You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). -Run the following command to autheticate your token +Run the following command to authenticate your token ```bash huggingface-cli login diff --git a/docs/source/training/unconditional_training.mdx b/docs/source/training/unconditional_training.mdx index 044f3937b9..e711e05973 100644 --- a/docs/source/training/unconditional_training.mdx +++ b/docs/source/training/unconditional_training.mdx @@ -10,23 +10,140 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +# Unconditional Image-Generation +In this section, we explain how one can train an unconditional image generation diffusion +model. "Unconditional" because the model is not conditioned on any context to generate an image - once trained the model will simply generate images that resemble its training data +distribution. -# Quicktour +## Installing the dependencies -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! +Before running the scripts, make sure to install the library's training dependencies: -``` -pip install diffusers +```bash +pip install diffusers[training] accelerate datasets ``` -## Main classes +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: -### Models +```bash +accelerate config +``` -### Schedulers +## Unconditional Flowers -### Pipeliens +The command to train a DDPM UNet model on the Oxford Flowers dataset: + +```bash +accelerate launch train_unconditional.py \ + --dataset_name="huggan/flowers-102-categories" \ + --resolution=64 \ + --output_dir="ddpm-ema-flowers-64" \ + --train_batch_size=16 \ + --num_epochs=100 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-4 \ + --lr_warmup_steps=500 \ + --mixed_precision=no \ + --push_to_hub +``` +An example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64 + +A full training run takes 2 hours on 4xV100 GPUs. + + + +## Unconditional Pokemon + +The command to train a DDPM UNet model on the Pokemon dataset: + +```bash +accelerate launch train_unconditional.py \ + --dataset_name="huggan/pokemon" \ + --resolution=64 \ + --output_dir="ddpm-ema-pokemon-64" \ + --train_batch_size=16 \ + --num_epochs=100 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-4 \ + --lr_warmup_steps=500 \ + --mixed_precision=no \ + --push_to_hub +``` +An example trained model: https://huggingface.co/anton-l/ddpm-ema-pokemon-64 + +A full training run takes 2 hours on 4xV100 GPUs. + + +## Using your own data + +To use your own dataset, there are 2 ways: +- you can either provide your own folder as `--train_data_dir` +- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument. + +**Note**: If you want to create your own training dataset please have a look at [this document](https://huggingface.co/docs/datasets/image_process#image-datasets). + +Below, we explain both in more detail. + +### Provide the dataset as a folder + +If you provide your own folders with images, the script expects the following directory structure: + +```bash +data_dir/xxx.png +data_dir/xxy.png +data_dir/[...]/xxz.png +``` + +In other words, the script will take care of gathering all images inside the folder. You can then run the script like this: + +```bash +accelerate launch train_unconditional.py \ + --train_data_dir \ + +``` + +Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects. + +### Upload your data to the hub, as a (possibly private) repo + +It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following: + +```python +from datasets import load_dataset + +# example 1: local folder +dataset = load_dataset("imagefolder", data_dir="path_to_your_folder") + +# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd) +dataset = load_dataset("imagefolder", data_files="path_to_zip_file") + +# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd) +dataset = load_dataset( + "imagefolder", + data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip", +) + +# example 4: providing several splits +dataset = load_dataset( + "imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]} +) +``` + +`ImageFolder` will create an `image` column containing the PIL-encoded images. + +Next, push it to the hub! + +```python +# assuming you have ran the huggingface-cli login command in a terminal +dataset.push_to_hub("name_of_your_dataset") + +# if you want to push to a private repo, simply pass private=True: +dataset.push_to_hub("name_of_your_dataset", private=True) +``` + +and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub. + +More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets). diff --git a/docs/source/using-diffusers/conditional_image_generation.mdx b/docs/source/using-diffusers/conditional_image_generation.mdx index 044f3937b9..6273a71d4c 100644 --- a/docs/source/using-diffusers/conditional_image_generation.mdx +++ b/docs/source/using-diffusers/conditional_image_generation.mdx @@ -10,23 +10,39 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +# Conditional Image Generation +The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference -# Quicktour +Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download. +You can use the [`DiffusionPipeline`] for any [Diffusers' checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads). +In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generation with [Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256): -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! +```python +>>> from diffusers import DiffusionPipeline +>>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") ``` -pip install diffusers +The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. +Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on GPU. +You can move the generator object to GPU, just like you would in PyTorch. + +```python +>>> generator.to("cuda") ``` -## Main classes +Now you can use the `generator` on your text prompt: -### Models +```python +>>> image = generator("An image of a squirrel in Picasso style").images[0] +``` -### Schedulers +The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class). -### Pipeliens +You can save the image by simply calling: + +```python +>>> image.save("image_of_squirrel_painting.png") +``` diff --git a/docs/source/using-diffusers/custom.mdx b/docs/source/using-diffusers/custom.mdx index 044f3937b9..2d17adaff7 100644 --- a/docs/source/using-diffusers/custom.mdx +++ b/docs/source/using-diffusers/custom.mdx @@ -10,23 +10,6 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +# Custom Pipeline - -# Quicktour - -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! - -``` -pip install diffusers -``` - -## Main classes - -### Models - -### Schedulers - -### Pipeliens - - +Under construction 🚧 diff --git a/docs/source/using-diffusers/img2img.mdx b/docs/source/using-diffusers/img2img.mdx index 044f3937b9..e3b0687144 100644 --- a/docs/source/using-diffusers/img2img.mdx +++ b/docs/source/using-diffusers/img2img.mdx @@ -10,23 +10,37 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +# Text-Guided Image-to-Image Generation +The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images. -# Quicktour +```python +from torch import autocast +import requests +from PIL import Image +from io import BytesIO -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! +from diffusers import StableDiffusionImg2ImgPipeline +# load the pipeline +device = "cuda" +pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True +).to(device) + +# let's download an initial image +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((768, 512)) + +prompt = "A fantasy landscape, trending on artstation" + +with autocast("cuda"): + images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images + +images[0].save("fantasy_landscape.png") ``` -pip install diffusers -``` - -## Main classes - -### Models - -### Schedulers - -### Pipeliens - +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) diff --git a/docs/source/using-diffusers/inpaint.mdx b/docs/source/using-diffusers/inpaint.mdx index 044f3937b9..215b2c8073 100644 --- a/docs/source/using-diffusers/inpaint.mdx +++ b/docs/source/using-diffusers/inpaint.mdx @@ -10,23 +10,41 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +# Text-Guided Image-Inpainting + +The [`StableDiffusionInpaintPipeline`] lets you edit specific parts of an image by providing a mask and text prompt. + +```python +from io import BytesIO + +from torch import autocast +import requests +import PIL + +from diffusers import StableDiffusionInpaintPipeline -# Quicktour +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! -``` -pip install diffusers +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + +init_image = download_image(img_url).resize((512, 512)) +mask_image = download_image(mask_url).resize((512, 512)) + +device = "cuda" +pipe = StableDiffusionInpaintPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=True +).to(device) + +prompt = "a cat sitting on a bench" +with autocast("cuda"): + images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images + +images[0].save("cat_on_bench.png") ``` -## Main classes - -### Models - -### Schedulers - -### Pipeliens - - +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) diff --git a/docs/source/using-diffusers/loading.mdx b/docs/source/using-diffusers/loading.mdx index 044f3937b9..44b514bbb2 100644 --- a/docs/source/using-diffusers/loading.mdx +++ b/docs/source/using-diffusers/loading.mdx @@ -10,23 +10,6 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> +# Loading - -# Quicktour - -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! - -``` -pip install diffusers -``` - -## Main classes - -### Models - -### Schedulers - -### Pipeliens - - +Under construction 🚧 diff --git a/docs/source/using-diffusers/unconditional_image_generation.mdx b/docs/source/using-diffusers/unconditional_image_generation.mdx index 044f3937b9..8f5449f8fb 100644 --- a/docs/source/using-diffusers/unconditional_image_generation.mdx +++ b/docs/source/using-diffusers/unconditional_image_generation.mdx @@ -12,21 +12,41 @@ specific language governing permissions and limitations under the License. -# Quicktour +# Unonditional Image Generation -Start using Diffusers🧨 quickly! -To start, use the [`DiffusionPipeline`] for quick inference and sample generations! +The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference +Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download. +You can use the [`DiffusionPipeline`] for any [Diffusers' checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads). +In this guide though, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239): + +```python +>>> from diffusers import DiffusionPipeline + +>>> generator = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256") ``` -pip install diffusers +The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. +Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on GPU. +You can move the generator object to GPU, just like you would in PyTorch. + +```python +>>> generator.to("cuda") ``` -## Main classes +Now you can use the `generator` on your text prompt: -### Models +```python +>>> image = generator().images[0] +``` + +The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class). + +You can save the image by simply calling: + +```python +>>> image.save("generated_image.png") +``` -### Schedulers -### Pipeliens diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md index ad3b405d0f..65b9d4958b 100644 --- a/examples/textual_inversion/README.md +++ b/examples/textual_inversion/README.md @@ -14,7 +14,7 @@ Colab for inference ## Running locally ### Installing the dependencies -Before running the scipts, make sure to install the library's training dependencies: +Before running the scripts, make sure to install the library's training dependencies: ```bash pip install diffusers[training] accelerate transformers @@ -33,7 +33,7 @@ You need to accept the model license before downloading or using the weights. In You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). -Run the following command to autheticate your token +Run the following command to authenticate your token ```bash huggingface-cli login diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 6a9b6a2375..de5761646a 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -238,7 +238,6 @@ class TextualInversionDataset(Dataset): placeholder_token="*", center_crop=False, ): - self.data_root = data_root self.tokenizer = tokenizer self.learnable_property = learnable_property @@ -423,7 +422,7 @@ def main(): eps=args.adam_epsilon, ) - # TODO (patil-suraj): laod scheduler using args + # TODO (patil-suraj): load scheduler using args noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" ) @@ -505,7 +504,9 @@ def main(): noise = torch.randn(latents.shape).to(latents.device) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long() + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index ad61058138..a2a5a840af 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -4,7 +4,7 @@ Creating a training image set is [described in a different document](https://hug ### Installing the dependencies -Before running the scipts, make sure to install the library's training dependencies: +Before running the scripts, make sure to install the library's training dependencies: ```bash pip install diffusers[training] accelerate datasets @@ -102,7 +102,7 @@ from datasets import load_dataset # example 1: local folder dataset = load_dataset("imagefolder", data_dir="path_to_your_folder") -# example 2: local files (suppoted formats are tar, gzip, zip, xz, rar, zstd) +# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd) dataset = load_dataset("imagefolder", data_files="path_to_zip_file") # example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index fe4e9b0d27..f6affe8a14 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -130,7 +130,7 @@ def main(args): bsz = clean_images.shape[0] # Sample a random timestep for each image timesteps = torch.randint( - 0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device ).long() # Add noise to the clean images according to the noise magnitude at each timestep diff --git a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py index 52d75c75ef..4222327c23 100644 --- a/scripts/convert_ddpm_original_checkpoint_to_diffusers.py +++ b/scripts/convert_ddpm_original_checkpoint_to_diffusers.py @@ -22,7 +22,7 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0): new_item = old_item new_item = new_item.replace("block.", "resnets.") new_item = new_item.replace("conv_shorcut", "conv1") - new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = new_item.replace("in_shortcut", "conv_shortcut") new_item = new_item.replace("temb_proj", "time_emb_proj") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py new file mode 100644 index 0000000000..ee7fc33543 --- /dev/null +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -0,0 +1,690 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the LDM checkpoints. """ + +import argparse +import os + +import torch + + +try: + from omegaconf import OmegaConf +except ImportError: + raise ImportError( + "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`." + ) + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + LDMTextToImagePipeline, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + unet_params = original_config.model.params.unet_config.params + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=unet_params.image_size, + in_channels=unet_params.in_channels, + out_channels=unet_params.out_channels, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=unet_params.num_res_blocks, + cross_attention_dim=unet_params.context_dim, + attention_head_dim=unet_params.num_heads, + ) + + return config + + +def create_vae_diffusers_config(original_config): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=vae_params.resolution, + in_channels=vae_params.in_channels, + out_channels=vae_params.out_ch, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=vae_params.z_channels, + layers_per_block=vae_params.num_res_blocks, + ) + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config.model.parms.cond_stage_config.params + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint(checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if ["conv.weight", "conv.bias"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml + parser.add_argument( + "--original_config_file", + default=None, + type=str, + help="The YAML config file corresponding to the original architecture.", + ) + parser.add_argument( + "--scheduler_type", + default="pndm", + type=str, + help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']", + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + args = parser.parse_args() + + if args.original_config_file is None: + os.system( + "wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + ) + args.original_config_file = "./v1-inference.yaml" + + original_config = OmegaConf.load(args.original_config_file) + checkpoint = torch.load(args.checkpoint_path)["state_dict"] + + num_train_timesteps = original_config.model.params.timesteps + beta_start = original_config.model.params.linear_start + beta_end = original_config.model.params.linear_end + if args.scheduler_type == "pndm": + scheduler = PNDMScheduler( + beta_end=beta_end, + beta_schedule="scaled_linear", + beta_start=beta_start, + num_train_timesteps=num_train_timesteps, + skip_prk_steps=True, + ) + elif args.scheduler_type == "lms": + scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear") + elif args.scheduler_type == "ddim": + scheduler = DDIMScheduler( + beta_start=beta_start, + beta_end=beta_end, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + else: + raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!") + + # Convert the UNet2DConditionModel model. + unet_config = create_unet_diffusers_config(original_config) + converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config) + + unet = UNet2DConditionModel(**unet_config) + unet.load_state_dict(converted_unet_checkpoint) + + # Convert the VAE model. + vae_config = create_vae_diffusers_config(original_config) + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + + # Convert the text model. + text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] + if text_model_type == "FrozenCLIPEmbedder": + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker") + feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker") + pipe = StableDiffusionPipeline( + vae=vae, + text_encoder=text_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + else: + text_config = create_ldm_bert_config(original_config) + text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) + tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + + pipe.save_pretrained(args.dump_path) diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py new file mode 100644 index 0000000000..0e4550b788 --- /dev/null +++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py @@ -0,0 +1,196 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path + +import torch +from torch.onnx import export + +from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline +from diffusers.onnx_utils import OnnxRuntimeModel +from packaging import version + + +is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11") + + +def onnx_export( + model, + model_args: tuple, + output_path: Path, + ordered_input_names, + output_names, + dynamic_axes, + opset, + use_external_data_format=False, +): + output_path.parent.mkdir(parents=True, exist_ok=True) + # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11, + # so we check the torch version for backwards compatibility + if is_torch_less_than_1_11: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + enable_onnx_checker=True, + opset_version=opset, + ) + else: + export( + model, + model_args, + f=output_path.as_posix(), + input_names=ordered_input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + opset_version=opset, + ) + + +@torch.no_grad() +def convert_models(model_path: str, output_path: str, opset: int): + pipeline = StableDiffusionPipeline.from_pretrained(model_path, use_auth_token=True) + output_path = Path(output_path) + + # TEXT ENCODER + text_input = pipeline.tokenizer( + "A sample prompt", + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + onnx_export( + pipeline.text_encoder, + # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files + model_args=(text_input.input_ids.to(torch.int32)), + output_path=output_path / "text_encoder" / "model.onnx", + ordered_input_names=["input_ids"], + output_names=["last_hidden_state", "pooler_output"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "sequence"}, + }, + opset=opset, + ) + + # UNET + onnx_export( + pipeline.unet, + model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False), + output_path=output_path / "unet" / "model.onnx", + ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], + output_names=["out_sample"], # has to be different from "sample" for correct tracing + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "timestep": {0: "batch"}, + "encoder_hidden_states": {0: "batch", 1: "sequence"}, + }, + opset=opset, + use_external_data_format=True, # UNet is > 2GB, so the weights need to be split + ) + + # VAE ENCODER + vae_encoder = pipeline.vae + # need to get the raw tensor output (sample) from the encoder + vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() + onnx_export( + vae_encoder, + model_args=(torch.randn(1, 3, 512, 512), False), + output_path=output_path / "vae_encoder" / "model.onnx", + ordered_input_names=["sample", "return_dict"], + output_names=["latent_sample"], + dynamic_axes={ + "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + opset=opset, + ) + + # VAE DECODER + vae_decoder = pipeline.vae + # forward only through the decoder part + vae_decoder.forward = vae_encoder.decode + onnx_export( + vae_decoder, + model_args=(torch.randn(1, 4, 64, 64), False), + output_path=output_path / "vae_decoder" / "model.onnx", + ordered_input_names=["latent_sample", "return_dict"], + output_names=["sample"], + dynamic_axes={ + "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + opset=opset, + ) + + # SAFETY CHECKER + safety_checker = pipeline.safety_checker + safety_checker.forward = safety_checker.forward_onnx + onnx_export( + pipeline.safety_checker, + model_args=(torch.randn(1, 3, 224, 224), torch.randn(1, 512, 512, 3)), + output_path=output_path / "safety_checker" / "model.onnx", + ordered_input_names=["clip_input", "images"], + output_names=["out_images", "has_nsfw_concepts"], + dynamic_axes={ + "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + "images": {0: "batch", 1: "channels", 2: "height", 3: "width"}, + }, + opset=opset, + ) + + onnx_pipeline = StableDiffusionOnnxPipeline( + vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"), + text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"), + tokenizer=pipeline.tokenizer, + unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"), + scheduler=pipeline.scheduler, + safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"), + feature_extractor=pipeline.feature_extractor, + ) + + onnx_pipeline.save_pretrained(output_path) + print("ONNX pipeline saved to", output_path) + + _ = StableDiffusionOnnxPipeline.from_pretrained(output_path, provider="CPUExecutionProvider") + print("ONNX pipeline is loadable") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model_path", + type=str, + required=True, + help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).", + ) + + parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--opset", + default=14, + type=str, + help="The version of the ONNX operator set to use.", + ) + + args = parser.parse_args() + + convert_models(args.model_path, args.output_path, args.opset) diff --git a/scripts/generate_logits.py b/scripts/generate_logits.py index 47dc5485af..531e2e4d71 100644 --- a/scripts/generate_logits.py +++ b/scripts/generate_logits.py @@ -124,4 +124,4 @@ for mod in models: assert torch.allclose( logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3 ) - print(f"{mod.modelId} has passed succesfully!!!") + print(f"{mod.modelId} has passed successfully!!!") diff --git a/setup.py b/setup.py index 7b71bd70d4..a9022ea7b6 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ To create the package for pypi. """ import re +import os from distutils.core import Command from setuptools import find_packages, setup @@ -78,14 +79,17 @@ from setuptools import find_packages, setup _deps = [ "Pillow", "accelerate>=0.11.0", - "black==22.3", + "black==22.8", "datasets", "filelock", "flake8>=3.8.3", + "flax>=0.4.1", "hf-doc-builder>=0.3.0", - "huggingface-hub>=0.8.1", + "huggingface-hub>=0.9.1", "importlib_metadata", "isort>=5.5.4", + "jax>=0.2.8,!=0.3.2,<=0.3.6", + "jaxlib>=0.1.65,<=0.3.6", "modelcards==0.1.4", "numpy", "pytest", @@ -167,11 +171,18 @@ extras = {} extras = {} -extras["quality"] = ["black==22.3", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"] +extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"] extras["docs"] = ["hf-doc-builder"] extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"] -extras["test"] = ["datasets", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"] -extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] +extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"] +extras["torch"] = deps_list("torch") + +if os.name == "nt": # windows + extras["flax"] = [] # jax is not supported on windows +else: + extras["flax"] = deps_list("jax", "jaxlib", "flax") + +extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] install_requires = [ deps["importlib_metadata"], @@ -180,13 +191,12 @@ install_requires = [ deps["numpy"], deps["regex"], deps["requests"], - deps["torch"], deps["Pillow"], ] setup( name="diffusers", - version="0.3.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.4.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="Diffusers", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", @@ -198,7 +208,7 @@ setup( package_dir={"": "src"}, packages=find_packages("src"), include_package_data=True, - python_requires=">=3.6.0", + python_requires=">=3.7.0", install_requires=install_requires, extras_require=extras, entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]}, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3b37e198ca..ecf4fe5fef 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,41 +1,53 @@ -from .utils import is_inflect_available, is_scipy_available, is_transformers_available, is_unidecode_available - - -__version__ = "0.3.0.dev0" - -from .modeling_utils import ModelMixin -from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel -from .optimization import ( - get_constant_schedule, - get_constant_schedule_with_warmup, - get_cosine_schedule_with_warmup, - get_cosine_with_hard_restarts_schedule_with_warmup, - get_linear_schedule_with_warmup, - get_polynomial_decay_schedule_with_warmup, - get_scheduler, -) -from .pipeline_utils import DiffusionPipeline -from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline -from .schedulers import ( - DDIMScheduler, - DDPMScheduler, - KarrasVeScheduler, - PNDMScheduler, - SchedulerMixin, - ScoreSdeVeScheduler, +from .utils import ( + is_flax_available, + is_inflect_available, + is_onnx_available, + is_scipy_available, + is_torch_available, + is_transformers_available, + is_unidecode_available, ) + + +__version__ = "0.4.0.dev0" + +from .configuration_utils import ConfigMixin +from .onnx_utils import OnnxRuntimeModel from .utils import logging -if is_scipy_available(): +if is_torch_available(): + from .modeling_utils import ModelMixin + from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel + from .optimization import ( + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, + get_scheduler, + ) + from .pipeline_utils import DiffusionPipeline + from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline + from .schedulers import ( + DDIMScheduler, + DDPMScheduler, + KarrasVeScheduler, + PNDMScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, + ) + from .training_utils import EMAModel +else: + from .utils.dummy_pt_objects import * # noqa F403 + +if is_torch_available() and is_scipy_available(): from .schedulers import LMSDiscreteScheduler else: - from .utils.dummy_scipy_objects import * # noqa F403 + from .utils.dummy_torch_and_scipy_objects import * # noqa F403 -from .training_utils import EMAModel - - -if is_transformers_available(): +if is_torch_available() and is_transformers_available(): from .pipelines import ( LDMTextToImagePipeline, StableDiffusionImg2ImgPipeline, @@ -43,4 +55,23 @@ if is_transformers_available(): StableDiffusionPipeline, ) else: - from .utils.dummy_transformers_objects import * # noqa F403 + from .utils.dummy_torch_and_transformers_objects import * # noqa F403 + +if is_torch_available() and is_transformers_available() and is_onnx_available(): + from .pipelines import StableDiffusionOnnxPipeline +else: + from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 + +if is_flax_available(): + from .modeling_flax_utils import FlaxModelMixin + from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .schedulers import ( + FlaxDDIMScheduler, + FlaxDDPMScheduler, + FlaxKarrasVeScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, + FlaxScoreSdeVeScheduler, + ) +else: + from .utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 053ccd6429..f5e5d36ffd 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ ConfigMixinuration base class and utilities.""" +import dataclasses import functools import inspect import json @@ -37,9 +38,16 @@ _re_configuration_file = re.compile(r"config\.(.*)\.json") class ConfigMixin: r""" - Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as - methods for loading/downloading/saving configurations. + Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all + methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with + - [`~ConfigMixin.from_config`] + - [`~ConfigMixin.save_config`] + Class attributes: + - **config_name** (`str`) -- A filename under which the config should stored when calling + [`~ConfigMixin.save_config`] (should be overridden by parent class). + - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be + overridden by parent class). """ config_name = None ignore_for_config = [] @@ -74,8 +82,6 @@ class ConfigMixin: Args: save_directory (`str` or `os.PathLike`): Directory where the configuration JSON file will be saved (will be created if it does not exist). - kwargs: - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ if os.path.isfile(save_directory): raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") @@ -90,6 +96,63 @@ class ConfigMixin: @classmethod def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): + r""" + Instantiate a Python class from a pre-defined JSON-file. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + + + Passing `use_auth_token=True`` is required when you want to use a private model. + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) @@ -112,6 +175,7 @@ class ConfigMixin: use_auth_token = kwargs.pop("use_auth_token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) + _ = kwargs.pop("mirror", None) subfolder = kwargs.pop("subfolder", None) user_agent = {"file_type": "config"} @@ -208,6 +272,11 @@ class ConfigMixin: # remove general kwargs if present in dict if "kwargs" in expected_keys: expected_keys.remove("kwargs") + # remove flax interal keys + if hasattr(cls, "_flax_internal_args"): + for arg in cls._flax_internal_args: + expected_keys.remove(arg) + # remove keys to be ignored if len(cls.ignore_for_config) > 0: expected_keys = expected_keys - set(cls.ignore_for_config) @@ -220,11 +289,20 @@ class ConfigMixin: # use value from config dict init_dict[key] = config_dict.pop(key) - unused_kwargs = config_dict.update(kwargs) + config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + + if len(config_dict) > 0: + logger.warning( + f"The config attributes {config_dict} were passed to {cls.__name__}, " + "but are not expected and will be ignored. Please verify your " + f"{cls.config_name} configuration file." + ) + + unused_kwargs = {**config_dict, **kwargs} passed_keys = set(init_dict.keys()) if len(expected_keys - passed_keys) > 0: - logger.warning( + logger.info( f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values." ) @@ -298,10 +376,10 @@ class FrozenDict(OrderedDict): def register_to_config(init): - """ - Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically - sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be - registered in the config, use the `ignore_for_config` class variable + r""" + Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are + automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that + shouldn't be registered in the config, use the `ignore_for_config` class variable Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init! """ @@ -338,3 +416,44 @@ def register_to_config(init): getattr(self, "register_to_config")(**new_kwargs) return inner_init + + +def flax_register_to_config(cls): + original_init = cls.__init__ + + @functools.wraps(original_init) + def init(self, *args, **kwargs): + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + # Ignore private kwargs in the init. Retrieve all passed attributes + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + + # Retrieve default values + fields = dataclasses.fields(self) + default_kwargs = {} + for field in fields: + # ignore flax specific attributes + if field.name in self._flax_internal_args: + continue + if type(field.default) == dataclasses._MISSING_TYPE: + default_kwargs[field.name] = None + else: + default_kwargs[field.name] = getattr(self, field.name) + + # Make sure init_kwargs override default kwargs + new_kwargs = {**default_kwargs, **init_kwargs} + + # Get positional arguments aligned with kwargs + for i, arg in enumerate(args): + name = fields[i].name + new_kwargs[name] = arg + + getattr(self, "register_to_config")(**new_kwargs) + original_init(self, *args, **kwargs) + + cls.__init__ = init + return cls diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 74c5331e5a..f6fb397303 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -4,14 +4,17 @@ deps = { "Pillow": "Pillow", "accelerate": "accelerate>=0.11.0", - "black": "black==22.3", + "black": "black==22.8", "datasets": "datasets", "filelock": "filelock", "flake8": "flake8>=3.8.3", + "flax": "flax>=0.4.1", "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.8.1", + "huggingface-hub": "huggingface-hub>=0.9.1", "importlib_metadata": "importlib_metadata", "isort": "isort>=5.5.4", + "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6", + "jaxlib": "jaxlib>=0.1.65,<=0.3.6", "modelcards": "modelcards==0.1.4", "numpy": "numpy", "pytest": "pytest", diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py new file mode 100644 index 0000000000..f195462eca --- /dev/null +++ b/src/diffusers/modeling_flax_utils.py @@ -0,0 +1,526 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pickle import UnpicklingError +from typing import Any, Dict, Union + +import jax +import jax.numpy as jnp +import msgpack.exceptions +from flax.core.frozen_dict import FrozenDict, unfreeze +from flax.serialization import from_bytes, to_bytes +from flax.traverse_util import flatten_dict, unflatten_dict +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from requests import HTTPError + +from .modeling_utils import WEIGHTS_NAME +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging + + +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" + +logger = logging.get_logger(__name__) + + +class FlaxModelMixin: + r""" + Base class for all flax models. + + [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models. + """ + config_name = CONFIG_NAME + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _flax_internal_args = ["name", "parent"] + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + + def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_map(conditional_cast, params) + + flat_params = flatten_dict(params) + flat_mask, _ = jax.tree_flatten(mask) + + for masked, key in zip(flat_mask, flat_params.keys()): + if masked: + param = flat_params[key] + flat_params[key] = conditional_cast(param) + + return unflatten_dict(flat_params) + + def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast + the `params` in place. + + This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full + half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip. + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # load model + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision + >>> params = model.to_bf16(params) + >>> # If you don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> flat_params = traverse_util.flatten_dict(params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> params = model.to_bf16(params, mask) + ```""" + return self._cast_floating_to(params, jnp.bfloat16, mask) + + def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the + model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # Download model and configuration from huggingface.co + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> # By default, the model params will be in fp32, to illustrate the use of this method, + >>> # we'll first cast to fp16 and back to fp32 + >>> params = model.to_f16(params) + >>> # now cast back to fp32 + >>> params = model.to_fp32(params) + ```""" + return self._cast_floating_to(params, jnp.float32, mask) + + def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the + `params` in place. + + This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full + half-precision training or to save weights in float16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # load model + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> # By default, the model params will be in fp32, to cast these to float16 + >>> params = model.to_fp16(params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> flat_params = traverse_util.flatten_dict(params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> params = model.to_fp16(params, mask) + ```""" + return self._cast_floating_to(params, jnp.float16, mask) + + def init_weights(self, rng: jax.random.PRNGKey) -> Dict: + raise NotImplementedError(f"init_weights method has to be implemented for {self}") + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + dtype: jnp.dtype = jnp.float32, + *model_args, + **kwargs, + ): + r""" + Instantiate a pretrained flax model from a pre-trained model configuration. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids are namespaced under a user or organization name, like + `CompVis/stable-diffusion-v1-4`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], + e.g., `./my_model_directory/`. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and + [`~ModelMixin.to_bf16`]. + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to + a configuration attribute will be used to override said attribute with the supplied `kwargs` + value. Remaining keys that do not correspond to any configuration attribute will be passed to the + underlying model's `__init__` function. + + Examples: + + ```python + >>> from diffusers import FlaxUNet2DConditionModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/") + ```""" + config = kwargs.pop("config", None) + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_auto_class = kwargs.pop("_from_auto", False) + subfolder = kwargs.pop("subfolder", None) + + user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} + + # Load config if we don't provide a configuration + config_path = config if config is not None else pretrained_model_name_or_path + model, model_kwargs = cls.from_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + # model args + dtype=dtype, + **kwargs, + ) + + # Load model + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) + # At this stage we don't have a weight file so we will raise an error. + elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + raise EnvironmentError( + f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " + "but there is a file for PyTorch weights." + ) + else: + raise EnvironmentError( + f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " + f"{pretrained_model_name_or_path}." + ) + else: + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login` and pass `use_auth_token=True`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" + f"{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your" + " internet connection or see how to run the library in offline mode at" + " 'https://huggingface.co/docs/transformers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." + ) + + try: + with open(model_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + try: + with open(model_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") + # make sure all arrays are stored as jnp.ndarray + # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 + state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) + + # flatten dicts + state = flatten_dict(state) + + params_shape_tree = jax.eval_shape(model.init_weights, rng=jax.random.PRNGKey(0)) + required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) + + shape_state = flatten_dict(unfreeze(params_shape_tree)) + + missing_keys = required_params - set(state.keys()) + unexpected_keys = set(state.keys()) - required_params + + if missing_keys: + logger.warning( + f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. " + "Make sure to call model.init_weights to initialize the missing weights." + ) + cls._missing_keys = missing_keys + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys = [] + for key in state.keys(): + if key in shape_state and state[key].shape != shape_state[key].shape: + raise ValueError( + f"Trying to load the pretrained weight for {key} failed: checkpoint has shape " + f"{state[key].shape} which is incompatible with the model shape {shape_state[key].shape}. " + ) + + # remove unexpected keys to not be saved again + for unexpected_key in unexpected_keys: + del state[unexpected_key] + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + # dictionary of key: dtypes for the model params + param_dtypes = jax.tree_map(lambda x: x.dtype, state) + # extract keys of parameters not in jnp.float32 + fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16] + bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16] + + # raise a warning if any of the parameters are not in jnp.float32 + if len(fp16_params) > 0: + logger.warning( + f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from " + f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n" + "You should probably UPCAST the model weights to float32 if this was not intended. " + "See [`~ModelMixin.to_fp32`] for further information on how to do this." + ) + + if len(bf16_params) > 0: + logger.warning( + f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from " + f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n" + "You should probably UPCAST the model weights to float32 if this was not intended. " + "See [`~ModelMixin.to_fp32`] for further information on how to do this." + ) + + return model, unflatten_dict(state) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + params: Union[Dict, FrozenDict], + is_main_process: bool = True, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~FlaxModelMixin.from_pretrained`]` class method + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # save model + output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) + with open(output_model_file, "wb") as f: + model_bytes = to_bytes(params) + f.write(model_bytes) + + logger.info(f"Model weights saved in {output_model_file}") diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index ec501e2ae1..ef1ead9ecf 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -117,27 +117,10 @@ class ModelMixin(torch.nn.Module): Base class for all models. [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading - and saving models as well as a few methods common to all models to: + and saving models. - - resize the input embeddings, - - prune heads in the self-attention heads. - - Class attributes (overridden by derived classes): - - - **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class for this - model architecture. - - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model, - taking as arguments: - - - **model** ([`ModelMixin`]) -- An instance of the model on which to load the TensorFlow checkpoint. - - **config** ([`PreTrainedConfigMixin`]) -- An instance of the configuration associated to the model. - - **path** (`str`) -- A path to the TensorFlow checkpoint. - - - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived - classes of the same architecture adding modules on top of the base model. - - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization. - - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP - models, `pixel_values` for vision models and `input_values` for speech models). + - **config_name** ([`str`]) -- A filename under which the model should be stored when calling + [`~modeling_utils.ModelMixin.save_pretrained`]. """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] @@ -150,11 +133,10 @@ class ModelMixin(torch.nn.Module): save_directory: Union[str, os.PathLike], is_main_process: bool = True, save_function: Callable = torch.save, - **kwargs, ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~ModelMixin.from_pretrained`]` class method. + `[`~modeling_utils.ModelMixin.from_pretrained`]` class method. Arguments: save_directory (`str` or `os.PathLike`): @@ -166,9 +148,6 @@ class ModelMixin(torch.nn.Module): save_function (`Callable`): The function to use to save the state dictionary. Useful on distributed training like TPUs when one need to replace `torch.save` by another method. - - kwargs: - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") @@ -219,39 +198,16 @@ class ModelMixin(torch.nn.Module): Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a - user or organization name, like `dbmdz/bert-base-german-cased`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], - e.g., `./my_model_directory/`. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. - config (`Union[ConfigMixin, str, os.PathLike]`, *optional*): - Can be either: - - - an instance of a class derived from [`ConfigMixin`], - - a string or path valid as input to [`~ConfigMixin.from_pretrained`]. - - ConfigMixinuration for the model to use instead of an automatically loaded configuration. - ConfigMixinuration can be automatically loaded when: - - - The model is a model provided by the library (loaded with the *model id* string of a pretrained - model). - - The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the save - directory. - - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - configuration JSON file named *config.json* is found in the directory. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - from_tf (`bool`, *optional*, defaults to `False`): - Load the model weights from a TensorFlow checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). - from_flax (`bool`, *optional*, defaults to `False`): - Load the model weights from a Flax checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): - Whether or not to raise an error if some of the weights from the checkpoint do not have the same size - as the weights of the model (if for instance, you are instantiating a model with 10 labels from a - checkpoint with 3 labels). + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -262,35 +218,25 @@ class ModelMixin(torch.nn.Module): A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): Whether or not to only look at local files (i.e., do not try to download the model). use_auth_token (`str` or *bool*, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). + when running `diffusers-cli login` (stored in `~/.huggingface`). revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + mirror (`str`, *optional*): Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please refer to the mirror site for more information. - kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). Behaves differently depending on whether a `config` is provided or - automatically loaded: - - - If a configuration is provided with `config`, `**kwargs` will be directly passed to the - underlying model's `__init__` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that corresponds - to a configuration attribute will be used to override said attribute with the supplied `kwargs` - value. Remaining keys that do not correspond to any configuration attribute will be passed to the - underlying model's `__init__` function. - Passing `use_auth_token=True`` is required when you want to use a private model. @@ -299,8 +245,8 @@ class ModelMixin(torch.nn.Module): - Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to - use this method in a firewalled environment. + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. @@ -404,7 +350,7 @@ class ModelMixin(torch.nn.Module): f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" f" directory containing a file named {WEIGHTS_NAME} or" " \nCheckout your internet connection or see how to run the library in" - " offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'." + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." ) except EnvironmentError: raise EnvironmentError( diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index a69d9014bd..e4cedbff8c 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1,4 +1,5 @@ import math +from typing import Optional import torch import torch.nn.functional as F @@ -10,16 +11,24 @@ class AttentionBlock(nn.Module): An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (:obj:`int`): The number of channels in the input and output. + num_head_channels (:obj:`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. """ def __init__( self, - channels, - num_head_channels=None, - num_groups=32, - rescale_output_factor=1.0, - eps=1e-5, + channels: int, + num_head_channels: Optional[int] = None, + num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, ): super().__init__() self.channels = channels @@ -86,16 +95,33 @@ class AttentionBlock(nn.Module): class SpatialTransformer(nn.Module): """ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image + standard transformer action. Finally, reshape to image. + + Parameters: + in_channels (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The number of context dimensions to use. """ - def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None): + def __init__( + self, + in_channels: int, + n_heads: int, + d_head: int, + depth: int = 1, + dropout: float = 0.0, + num_groups: int = 32, + context_dim: Optional[int] = None, + ): super().__init__() self.n_heads = n_heads self.d_head = d_head self.in_channels = in_channels inner_dim = n_heads * d_head - self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) @@ -112,22 +138,44 @@ class SpatialTransformer(nn.Module): for block in self.transformer_blocks: block._set_attention_slice(slice_size) - def forward(self, x, context=None): + def forward(self, hidden_states, context=None): # note: if no context is given, cross-attention defaults to self-attention - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel) for block in self.transformer_blocks: - x = block(x, context=context) - x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) - x = self.proj_out(x) - return x + x_in + hidden_states = block(hidden_states, context=context) + hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2) + hidden_states = self.proj_out(hidden_states) + return hidden_states + residual class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True): + r""" + A basic Transformer block. + + Parameters: + dim (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. + gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. + checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. + """ + + def __init__( + self, + dim: int, + n_heads: int, + d_head: int, + dropout=0.0, + context_dim: Optional[int] = None, + gated_ff: bool = True, + checkpoint: bool = True, + ): super().__init__() self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout @@ -145,15 +193,30 @@ class BasicTransformerBlock(nn.Module): self.attn1._slice_size = slice_size self.attn2._slice_size = slice_size - def forward(self, x, context=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x + def forward(self, hidden_states, context=None): + hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states + hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states + hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + r""" + A cross attention layer. + + Parameters: + query_dim (:obj:`int`): The number of channels in the query. + context_dim (:obj:`int`, *optional*): + The number of channels in the context. If not given, defaults to `query_dim`. + heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + def __init__( + self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 + ): super().__init__() inner_dim = dim_head * heads context_dim = context_dim if context_dim is not None else query_dim @@ -185,26 +248,39 @@ class CrossAttention(nn.Module): tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor - def forward(self, x, context=None, mask=None): - batch_size, sequence_length, dim = x.shape + def forward(self, hidden_states, context=None, mask=None): + batch_size, sequence_length, dim = hidden_states.shape - q = self.to_q(x) - context = context if context is not None else x - k = self.to_k(context) - v = self.to_v(context) + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) - q = self.reshape_heads_to_batch_dim(q) - k = self.reshape_heads_to_batch_dim(k) - v = self.reshape_heads_to_batch_dim(v) + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) # TODO(PVP) - mask is currently never used. Remember to re-implement when used # attention, what we cannot get enough of - hidden_states = self._attention(q, k, v, sequence_length, dim) + + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) return self.to_out(hidden_states) - def _attention(self, query, key, value, sequence_length, dim): + def _attention(self, query, key, value): + attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + attention_probs = attention_scores.softmax(dim=-1) + # compute attention output + hidden_states = torch.matmul(attention_probs, value) + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim): batch_size_attention = query.shape[0] hidden_states = torch.zeros( (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype @@ -213,11 +289,9 @@ class CrossAttention(nn.Module): for i in range(hidden_states.shape[0] // slice_size): start_idx = i * slice_size end_idx = (i + 1) * slice_size - attn_slice = ( - torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale - ) + attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale attn_slice = attn_slice.softmax(dim=-1) - attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) + attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice @@ -227,7 +301,20 @@ class CrossAttention(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + r""" + A feed-forward layer. + + Parameters: + dim (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + def __init__( + self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 + ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim @@ -235,16 +322,24 @@ class FeedForward(nn.Module): self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) - def forward(self, x): - return self.net(x) + def forward(self, hidden_states): + return self.net(hidden_states) # feedforward class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * F.gelu(gate) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py new file mode 100644 index 0000000000..918c7469a7 --- /dev/null +++ b/src/diffusers/models/attention_flax.py @@ -0,0 +1,180 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import flax.linen as nn +import jax.numpy as jnp + + +class FlaxAttentionBlock(nn.Module): + query_dim: int + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim_head * self.heads + self.scale = self.dim_head**-0.5 + + # Weights were exported with old names {to_q, to_k, to_v, to_out} + self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") + self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") + self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") + + self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out") + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def __call__(self, hidden_states, context=None, deterministic=True): + context = hidden_states if context is None else context + + query_proj = self.query(hidden_states) + key_proj = self.key(context) + value_proj = self.value(context) + + query_states = self.reshape_heads_to_batch_dim(query_proj) + key_states = self.reshape_heads_to_batch_dim(key_proj) + value_states = self.reshape_heads_to_batch_dim(value_proj) + + # compute attentions + attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) + attention_scores = attention_scores * self.scale + attention_probs = nn.softmax(attention_scores, axis=2) + + # attend to values + hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.proj_attn(hidden_states) + return hidden_states + + +class FlaxBasicTransformerBlock(nn.Module): + dim: int + n_heads: int + d_head: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # self attention + self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + # cross attention + self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) + self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + + def __call__(self, hidden_states, context, deterministic=True): + # self attention + residual = hidden_states + hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic) + hidden_states = hidden_states + residual + + # cross attention + residual = hidden_states + hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic) + hidden_states = hidden_states + residual + + # feed forward + residual = hidden_states + hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic) + hidden_states = hidden_states + residual + + return hidden_states + + +class FlaxSpatialTransformer(nn.Module): + in_channels: int + n_heads: int + d_head: int + depth: int = 1 + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) + + inner_dim = self.n_heads * self.d_head + self.proj_in = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + self.transformer_blocks = [ + FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype) + for _ in range(self.depth) + ] + + self.proj_out = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, context, deterministic=True): + batch, height, width, channels = hidden_states.shape + # import ipdb; ipdb.set_trace() + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + + hidden_states = hidden_states.reshape(batch, height * width, channels) + + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) + + hidden_states = hidden_states.reshape(batch, height, width, channels) + + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + return hidden_states + + +class FlaxGluFeedForward(nn.Module): + dim: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim * 4 + self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype) + self.dense2 = nn.Dense(self.dim, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.dense1(hidden_states) + hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) + hidden_states = hidden_linear * nn.gelu(hidden_gelu) + hidden_states = self.dense2(hidden_states) + return hidden_states diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 99bfa96f0d..d8a6cf105a 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -19,7 +19,12 @@ from torch import nn def get_timestep_embedding( - timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000 + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, ): """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. @@ -55,7 +60,7 @@ def get_timestep_embedding( class TimestepEmbedding(nn.Module): - def __init__(self, channel, time_embed_dim, act_fn="silu"): + def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): super().__init__() self.linear_1 = nn.Linear(channel, time_embed_dim) @@ -75,7 +80,7 @@ class TimestepEmbedding(nn.Module): class Timesteps(nn.Module): - def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift): + def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): super().__init__() self.num_channels = num_channels self.flip_sin_to_cos = flip_sin_to_cos @@ -94,7 +99,7 @@ class Timesteps(nn.Module): class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" - def __init__(self, embedding_size=256, scale=1.0): + def __init__(self, embedding_size: int = 256, scale: float = 1.0): super().__init__() self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py new file mode 100644 index 0000000000..63442ab997 --- /dev/null +++ b/src/diffusers/models/embeddings_flax.py @@ -0,0 +1,56 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import flax.linen as nn +import jax.numpy as jnp + + +# This is like models.embeddings.get_timestep_embedding (PyTorch) but +# less general (only handles the case we currently need). +def get_sinusoidal_embeddings(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] tensor of positional embeddings. + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = jnp.exp(jnp.arange(half_dim) * -emb) + emb = timesteps[:, None] * emb[None, :] + emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1) + return emb + + +class FlaxTimestepEmbedding(nn.Module): + time_embed_dim: int = 32 + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, temb): + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) + temb = nn.silu(temb) + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) + return temb + + +class FlaxTimesteps(nn.Module): + dim: int = 32 + + @nn.compact + def __call__(self, timesteps): + return get_sinusoidal_embeddings(timesteps, self.dim) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index f98d11e417..785a4b9135 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -1,6 +1,5 @@ from functools import partial -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -134,10 +133,10 @@ class FirUpsample2D(nn.Module): kernel = [1] * factor # setup kernel - kernel = np.asarray(kernel, dtype=np.float32) + kernel = torch.tensor(kernel, dtype=torch.float32) if kernel.ndim == 1: - kernel = np.outer(kernel, kernel) - kernel /= np.sum(kernel) + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) kernel = kernel * (gain * (factor**2)) @@ -219,10 +218,10 @@ class FirDownsample2D(nn.Module): kernel = [1] * factor # setup kernel - kernel = np.asarray(kernel, dtype=np.float32) + kernel = torch.tensor(kernel, dtype=torch.float32) if kernel.ndim == 1: - kernel = np.outer(kernel, kernel) - kernel /= np.sum(kernel) + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) kernel = kernel * gain @@ -265,7 +264,7 @@ class ResnetBlock2D(nn.Module): time_embedding_norm="default", kernel=None, output_scale_factor=1.0, - use_nin_shortcut=None, + use_in_shortcut=None, up=False, down=False, ): @@ -322,10 +321,10 @@ class ResnetBlock2D(nn.Module): else: self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") - self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut self.conv_shortcut = None - if self.use_nin_shortcut: + if self.use_in_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb): @@ -391,16 +390,14 @@ def upsample_2d(x, kernel=None, factor=2, gain=1): if kernel is None: kernel = [1] * factor - kernel = np.asarray(kernel, dtype=np.float32) + kernel = torch.tensor(kernel, dtype=torch.float32) if kernel.ndim == 1: - kernel = np.outer(kernel, kernel) - kernel /= np.sum(kernel) + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) kernel = kernel * (gain * (factor**2)) p = kernel.shape[0] - factor - return upfirdn2d_native( - x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2) - ) + return upfirdn2d_native(x, kernel.to(device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) def downsample_2d(x, kernel=None, factor=2, gain=1): @@ -425,14 +422,14 @@ def downsample_2d(x, kernel=None, factor=2, gain=1): if kernel is None: kernel = [1] * factor - kernel = np.asarray(kernel, dtype=np.float32) + kernel = torch.tensor(kernel, dtype=torch.float32) if kernel.ndim == 1: - kernel = np.outer(kernel, kernel) - kernel /= np.sum(kernel) + kernel = torch.outer(kernel, kernel) + kernel /= torch.sum(kernel) kernel = kernel * gain p = kernel.shape[0] - factor - return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + return upfirdn2d_native(x, kernel.to(device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): @@ -448,10 +445,15 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): kernel_h, kernel_w = kernel.shape out = input.view(-1, in_h, 1, in_w, 1, minor) + + # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 + if input.device.type == "mps": + out = out.to("cpu") out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out.to(input.device) # Move back to mps if necessary out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), diff --git a/src/diffusers/models/resnet_flax.py b/src/diffusers/models/resnet_flax.py new file mode 100644 index 0000000000..46ccee35ad --- /dev/null +++ b/src/diffusers/models/resnet_flax.py @@ -0,0 +1,111 @@ +import flax.linen as nn +import jax +import jax.numpy as jnp + + +class FlaxUpsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + batch, height, width, channels = hidden_states.shape + hidden_states = jax.image.resize( + hidden_states, + shape=(batch, height * 2, width * 2, channels), + method="nearest", + ) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxDownsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), # padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim + # hidden_states = jnp.pad(hidden_states, pad_width=pad) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxResnetBlock2D(nn.Module): + in_channels: int + out_channels: int = None + dropout_prob: float = 0.0 + use_nin_shortcut: bool = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + out_channels = self.in_channels if self.out_channels is None else self.out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv1 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) + + self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.dropout = nn.Dropout(self.dropout_prob) + self.conv2 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut + + self.conv_shortcut = None + if use_nin_shortcut: + self.conv_shortcut = nn.Conv( + out_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, temb, deterministic=True): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.conv1(hidden_states) + + temb = self.time_emb_proj(nn.swish(temb)) + temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 46d5ee5329..89321a5503 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -23,6 +23,38 @@ class UNet2DOutput(BaseOutput): class UNet2DModel(ModelMixin, ConfigMixin): + r""" + UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*): + Input sample size. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to : + obj:`False`): Whether to flip sin to cos for fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block + types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(224, 448, 672, 896)`): Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. + mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. + downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. + """ + @register_to_config def __init__( self, @@ -82,6 +114,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, attn_num_head_channels=attention_head_dim, downsample_padding=downsample_padding, ) @@ -119,6 +152,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): add_upsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, attn_num_head_channels=attention_head_dim, ) self.up_blocks.append(up_block) @@ -136,6 +170,17 @@ class UNet2DModel(ModelMixin, ConfigMixin): timestep: Union[torch.Tensor, float, int], return_dict: bool = True, ) -> Union[UNet2DOutput, Tuple]: + """r + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, + otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 10e2ece546..42b54657d2 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -23,6 +23,37 @@ class UNet2DConditionOutput(BaseOutput): class UNet2DConditionModel(ModelMixin, ConfigMixin): + r""" + UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optional*): The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + @register_to_config def __init__( self, @@ -83,6 +114,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, downsample_padding=downsample_padding, @@ -122,6 +154,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): add_upsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attention_head_dim, ) @@ -162,6 +195,19 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): encoder_hidden_states: torch.Tensor, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: + """r + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py new file mode 100644 index 0000000000..636a7ef981 --- /dev/null +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -0,0 +1,258 @@ +from typing import Tuple, Union + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..modeling_flax_utils import FlaxModelMixin +from ..utils import BaseOutput +from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from .unet_blocks_flax import ( + FlaxCrossAttnDownBlock2D, + FlaxCrossAttnUpBlock2D, + FlaxDownBlock2D, + FlaxUNetMidBlock2DCrossAttn, + FlaxUpBlock2D, +) + + +@flax.struct.dataclass +class FlaxUNet2DConditionOutput(BaseOutput): + """ + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: jnp.ndarray + + +@flax_register_to_config +class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a + timestep and returns sample shaped output. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optional*): The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", + "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D", + "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks. + """ + + sample_size: int = 32 + in_channels: int = 4 + out_channels: int = 4 + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ) + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + block_out_channels: Tuple[int] = (320, 640, 1280, 1280) + layers_per_block: int = 2 + attention_head_dim: int = 8 + cross_attention_dim: int = 1280 + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: + # init input tensors + sample_shape = (1, self.sample_size, self.sample_size, self.in_channels) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"] + + def setup(self): + block_out_channels = self.block_out_channels + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # time + self.time_proj = FlaxTimesteps(block_out_channels[0]) + self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + + # down + down_blocks = [] + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlock2D": + down_block = FlaxCrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + attn_num_head_channels=self.attention_head_dim, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + else: + down_block = FlaxDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + + down_blocks.append(down_block) + self.down_blocks = down_blocks + + # mid + self.mid_block = FlaxUNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + dropout=self.dropout, + attn_num_head_channels=self.attention_head_dim, + dtype=self.dtype, + ) + + # up + up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(self.up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "CrossAttnUpBlock2D": + up_block = FlaxCrossAttnUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=self.layers_per_block + 1, + attn_num_head_channels=self.attention_head_dim, + add_upsample=not is_final_block, + dropout=self.dropout, + dtype=self.dtype, + ) + else: + up_block = FlaxUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=self.layers_per_block + 1, + add_upsample=not is_final_block, + dropout=self.dropout, + dtype=self.dtype, + ) + + up_blocks.append(up_block) + prev_output_channel = output_channel + self.up_blocks = up_blocks + + # out + self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv_out = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__( + self, + sample, + timesteps, + encoder_hidden_states, + return_dict: bool = True, + train: bool = False, + ) -> Union[FlaxUNet2DConditionOutput, Tuple]: + """r + Args: + sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor + timestep (`jnp.ndarray` or `float` or `int`): timesteps + encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a + plain tuple. + train (`bool`, *optional*, defaults to `False`): + Use deterministic functions and disable dropout when not training. + + Returns: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + # 1. time + t_emb = self.time_proj(timesteps) + t_emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + if isinstance(down_block, FlaxCrossAttnDownBlock2D): + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + else: + sample, res_samples = down_block(sample, t_emb, deterministic=not train) + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + + # 5. up + for up_block in self.up_blocks: + res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] + down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)] + if isinstance(up_block, FlaxCrossAttnUpBlock2D): + sample = up_block( + sample, + temb=t_emb, + encoder_hidden_states=encoder_hidden_states, + res_hidden_states_tuple=res_samples, + deterministic=not train, + ) + else: + sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = nn.silu(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return FlaxUNet2DConditionOutput(sample=sample) diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index 9e06216535..1fee670c91 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -31,6 +31,7 @@ def get_down_block( resnet_eps, resnet_act_fn, attn_num_head_channels, + resnet_groups=None, cross_attention_dim=None, downsample_padding=None, ): @@ -44,6 +45,7 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, downsample_padding=downsample_padding, ) elif down_block_type == "AttnDownBlock2D": @@ -55,6 +57,7 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, ) @@ -69,6 +72,7 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, @@ -104,6 +108,7 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, downsample_padding=downsample_padding, ) @@ -119,6 +124,7 @@ def get_up_block( resnet_eps, resnet_act_fn, attn_num_head_channels, + resnet_groups=None, cross_attention_dim=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type @@ -132,6 +138,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: @@ -145,6 +152,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attn_num_head_channels=attn_num_head_channels, ) @@ -158,6 +166,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, attn_num_head_channels=attn_num_head_channels, ) elif up_block_type == "SkipUpBlock2D": @@ -191,6 +200,7 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, ) raise ValueError(f"{up_block_type} does not exist.") @@ -323,6 +333,7 @@ class UNetMidBlock2DCrossAttn(nn.Module): in_channels // attn_num_head_channels, depth=1, context_dim=cross_attention_dim, + num_groups=resnet_groups, ) ) resnets.append( @@ -414,6 +425,7 @@ class AttnDownBlock2D(nn.Module): num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + num_groups=resnet_groups, ) ) @@ -498,6 +510,7 @@ class CrossAttnDownBlock2D(nn.Module): out_channels // attn_num_head_channels, depth=1, context_dim=cross_attention_dim, + num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) @@ -807,7 +820,7 @@ class AttnSkipDownBlock2D(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - use_nin_shortcut=True, + use_in_shortcut=True, down=True, kernel="fir", ) @@ -887,7 +900,7 @@ class SkipDownBlock2D(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - use_nin_shortcut=True, + use_in_shortcut=True, down=True, kernel="fir", ) @@ -966,6 +979,7 @@ class AttnUpBlock2D(nn.Module): num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + num_groups=resnet_groups, ) ) @@ -979,7 +993,6 @@ class AttnUpBlock2D(nn.Module): def forward(self, hidden_states, res_hidden_states_tuple, temb=None): for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] @@ -1048,6 +1061,7 @@ class CrossAttnUpBlock2D(nn.Module): out_channels // attn_num_head_channels, depth=1, context_dim=cross_attention_dim, + num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) @@ -1075,7 +1089,6 @@ class CrossAttnUpBlock2D(nn.Module): def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None): for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] @@ -1139,7 +1152,6 @@ class UpBlock2D(nn.Module): def forward(self, hidden_states, res_hidden_states_tuple, temb=None): for resnet in self.resnets: - # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] @@ -1343,7 +1355,7 @@ class AttnSkipUpBlock2D(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - use_nin_shortcut=True, + use_in_shortcut=True, up=True, kernel="fir", ) @@ -1440,7 +1452,7 @@ class SkipUpBlock2D(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - use_nin_shortcut=True, + use_in_shortcut=True, up=True, kernel="fir", ) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py new file mode 100644 index 0000000000..ce67eb12b1 --- /dev/null +++ b/src/diffusers/models/unet_blocks_flax.py @@ -0,0 +1,263 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import flax.linen as nn +import jax.numpy as jnp + +from .attention_flax import FlaxSpatialTransformer +from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D + + +class FlaxCrossAttnDownBlock2D(nn.Module): + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxSpatialTransformer( + in_channels=self.out_channels, + n_heads=self.attn_num_head_channels, + d_head=self.out_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_downsample: + self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsample(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxDownBlock2D(nn.Module): + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + self.resnets = resnets + + if self.add_downsample: + self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, deterministic=True): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsample(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxCrossAttnUpBlock2D(nn.Module): + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxSpatialTransformer( + in_channels=self.out_channels, + n_heads=self.attn_num_head_channels, + d_head=self.out_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_upsample: + self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class FlaxUpBlock2D(nn.Module): + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + + if self.add_upsample: + self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class FlaxUNetMidBlock2DCrossAttn(nn.Module): + in_channels: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # there is always at least one resnet + resnets = [ + FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + ] + + attentions = [] + + for _ in range(self.num_layers): + attn_block = FlaxSpatialTransformer( + in_channels=self.in_channels, + n_heads=self.attn_num_head_channels, + d_head=self.in_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + res_block = FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + self.attentions = attentions + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + + return hidden_states diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index c0a185784c..fe89b41c07 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -59,6 +59,7 @@ class Encoder(nn.Module): down_block_types=("DownEncoderBlock2D",), block_out_channels=(64,), layers_per_block=2, + norm_num_groups=32, act_fn="silu", double_z=True, ): @@ -86,6 +87,7 @@ class Encoder(nn.Module): resnet_eps=1e-6, downsample_padding=0, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, attn_num_head_channels=None, temb_channels=None, ) @@ -99,13 +101,12 @@ class Encoder(nn.Module): output_scale_factor=1, resnet_time_scale_shift="default", attn_num_head_channels=None, - resnet_groups=32, + resnet_groups=norm_num_groups, temb_channels=None, ) # out - num_groups_out = 32 - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels @@ -138,6 +139,7 @@ class Decoder(nn.Module): up_block_types=("UpDecoderBlock2D",), block_out_channels=(64,), layers_per_block=2, + norm_num_groups=32, act_fn="silu", ): super().__init__() @@ -156,7 +158,7 @@ class Decoder(nn.Module): output_scale_factor=1, resnet_time_scale_shift="default", attn_num_head_channels=None, - resnet_groups=32, + resnet_groups=norm_num_groups, temb_channels=None, ) @@ -178,6 +180,7 @@ class Decoder(nn.Module): add_upsample=not is_final_block, resnet_eps=1e-6, resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, attn_num_head_channels=None, temb_channels=None, ) @@ -185,8 +188,7 @@ class Decoder(nn.Module): prev_output_channel = output_channel # out - num_groups_out = 32 - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6) + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) @@ -338,7 +340,10 @@ class DiagonalGaussianDistribution(object): self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: - x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device) + device = self.parameters.device + sample_device = "cpu" if device.type == "mps" else device + sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device) + x = self.mean + self.std * sample return x def kl(self, other=None): @@ -368,6 +373,27 @@ class DiagonalGaussianDistribution(object): class VQModel(ModelMixin, ConfigMixin): + r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray + Kavukcuoglu. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + """ + @register_to_config def __init__( self, @@ -381,6 +407,7 @@ class VQModel(ModelMixin, ConfigMixin): latent_channels: int = 3, sample_size: int = 32, num_vq_embeddings: int = 256, + norm_num_groups: int = 32, ): super().__init__() @@ -392,6 +419,7 @@ class VQModel(ModelMixin, ConfigMixin): block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, + norm_num_groups=norm_num_groups, double_z=False, ) @@ -409,6 +437,7 @@ class VQModel(ModelMixin, ConfigMixin): block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, + norm_num_groups=norm_num_groups, ) def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: @@ -437,6 +466,12 @@ class VQModel(ModelMixin, ConfigMixin): return DecoderOutput(sample=dec) def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ x = sample h = self.encode(x).latents dec = self.decode(h).sample @@ -448,6 +483,26 @@ class VQModel(ModelMixin, ConfigMixin): class AutoencoderKL(ModelMixin, ConfigMixin): + r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma + and Max Welling. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + """ + @register_to_config def __init__( self, @@ -459,6 +514,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 4, + norm_num_groups: int = 32, sample_size: int = 32, ): super().__init__() @@ -471,6 +527,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, + norm_num_groups=norm_num_groups, double_z=True, ) @@ -481,6 +538,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): up_block_types=up_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, act_fn=act_fn, ) @@ -507,12 +565,24 @@ class AutoencoderKL(ModelMixin, ConfigMixin): return DecoderOutput(sample=dec) def forward( - self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ x = sample posterior = self.encode(x).latent_dist if sample_posterior: - z = posterior.sample() + z = posterior.sample(generator=generator) else: z = posterior.mode() dec = self.decode(z).sample diff --git a/src/diffusers/onnx_utils.py b/src/diffusers/onnx_utils.py new file mode 100644 index 0000000000..3c2a0b4829 --- /dev/null +++ b/src/diffusers/onnx_utils.py @@ -0,0 +1,187 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import shutil +from pathlib import Path +from typing import Optional, Union + +import numpy as np + +from huggingface_hub import hf_hub_download + +from .utils import is_onnx_available, logging + + +if is_onnx_available(): + import onnxruntime as ort + + +ONNX_WEIGHTS_NAME = "model.onnx" + + +logger = logging.get_logger(__name__) + + +class OnnxRuntimeModel: + def __init__(self, model=None, **kwargs): + logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.") + self.model = model + self.model_save_dir = kwargs.get("model_save_dir", None) + self.latest_model_name = kwargs.get("latest_model_name", "model.onnx") + + def __call__(self, **kwargs): + inputs = {k: np.array(v) for k, v in kwargs.items()} + return self.model.run(None, inputs) + + @staticmethod + def load_model(path: Union[str, Path], provider=None): + """ + Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` + + Arguments: + path (`str` or `Path`): + Directory from which to load + provider(`str`, *optional*): + Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider` + """ + if provider is None: + logger.info("No onnxruntime provider specified, using CPUExecutionProvider") + provider = "CPUExecutionProvider" + + return ort.InferenceSession(path, providers=[provider]) + + def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the + latest_model_name. + + Arguments: + save_directory (`str` or `Path`): + Directory where to save the model file. + file_name(`str`, *optional*): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the + model with a different name. + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + + src_path = self.model_save_dir.joinpath(self.latest_model_name) + dst_path = Path(save_directory).joinpath(model_file_name) + if not src_path.samefile(dst_path): + shutil.copyfile(src_path, dst_path) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + **kwargs, + ): + """ + Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class + method.: + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + # saving model weights/files + self._save_pretrained(save_directory, **kwargs) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + **kwargs, + ): + """ + Load a model from a directory or the HF Hub. + + Arguments: + model_id (`str` or `Path`): + Directory from which to load + use_auth_token (`str` or `bool`): + Is needed to load models from a private or gated repository + revision (`str`): + Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id + cache_dir (`Union[str, Path]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + file_name(`str`): + Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load + different model files from the same repository or directory. + provider(`str`): + The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`. + kwargs (`Dict`, *optional*): + kwargs will be passed to the model during initialization + """ + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + # load model from local directory + if os.path.isdir(model_id): + model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider) + kwargs["model_save_dir"] = Path(model_id) + # load model from hub + else: + # download model + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_name, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + ) + kwargs["model_save_dir"] = Path(model_cache_path).parent + kwargs["latest_model_name"] = Path(model_cache_path).name + model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider) + return cls(model=model, **kwargs) + + @classmethod + def from_pretrained( + cls, + model_id: Union[str, Path], + force_download: bool = True, + use_auth_token: Optional[str] = None, + cache_dir: Optional[str] = None, + **model_kwargs, + ): + revision = None + if len(str(model_id).split("@")) == 2: + model_id, revision = model_id.split("@") + + return cls._from_pretrained( + model_id=model_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + use_auth_token=use_auth_token, + **model_kwargs, + ) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index fc2bc7bcf4..847513bf15 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -23,13 +23,17 @@ from typing import List, Optional, Union import numpy as np import torch +import diffusers import PIL from huggingface_hub import snapshot_download from PIL import Image from tqdm.auto import tqdm from .configuration_utils import ConfigMixin -from .utils import DIFFUSERS_CACHE, BaseOutput, logging +from .modeling_utils import WEIGHTS_NAME +from .onnx_utils import ONNX_WEIGHTS_NAME +from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging INDEX_FILE = "diffusion_pytorch_model.bin" @@ -43,6 +47,7 @@ LOADABLE_CLASSES = { "ModelMixin": ["save_pretrained", "from_pretrained"], "SchedulerMixin": ["save_config", "from_config"], "DiffusionPipeline": ["save_pretrained", "from_pretrained"], + "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], }, "transformers": { "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], @@ -72,7 +77,20 @@ class ImagePipelineOutput(BaseOutput): class DiffusionPipeline(ConfigMixin): + r""" + Base class for all models. + [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines + and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to: + + - move all PyTorch modules to the device of your choice + - enabling/disabling the progress bar for the denoising iteration + + Class attributes: + + - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all + components of the diffusion pipeline. + """ config_name = "model_index.json" def register_modules(self, **kwargs): @@ -80,7 +98,7 @@ class DiffusionPipeline(ConfigMixin): from diffusers import pipelines for name, module in kwargs.items(): - # retrive library + # retrieve library library = module.__module__.split(".")[0] # check if the module is a pipeline module @@ -94,7 +112,7 @@ class DiffusionPipeline(ConfigMixin): if library not in LOADABLE_CLASSES or is_pipeline_module: library = pipeline_dir - # retrive class_name + # retrieve class_name class_name = module.__class__.__name__ register_dict = {name: (library, class_name)} @@ -106,6 +124,15 @@ class DiffusionPipeline(ConfigMixin): setattr(self, name, module) def save_pretrained(self, save_directory: Union[str, os.PathLike]): + """ + Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to + a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading + method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ self.save_config(save_directory) model_index_dict = dict(self.config) @@ -146,6 +173,10 @@ class DiffusionPipeline(ConfigMixin): @property def device(self) -> torch.device: + r""" + Returns: + `torch.device`: The torch device on which the pipeline is located. + """ module_names, _ = self.extract_init_dict(dict(self.config)) for name in module_names.keys(): module = getattr(self, name) @@ -156,7 +187,94 @@ class DiffusionPipeline(ConfigMixin): @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" - Add docstrings + Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. specify the folder name here. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the + specific pipeline class. The overritten components are then directly passed to the pipelines `__init__` + method. See example below for more information. + + + + Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.* + `"CompVis/stable-diffusion-v1-4"` + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + Examples: + + ```py + >>> from diffusers import DiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + + >>> # Download pipeline, but overwrite scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + >>> pipeline = DiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True + ... ) + ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) resume_download = kwargs.pop("resume_download", False) @@ -165,10 +283,26 @@ class DiffusionPipeline(ConfigMixin): use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) + provider = kwargs.pop("provider", None) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): + config_dict = cls.get_config_dict( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + # make sure we only download sub-folders and `diffusers` filenames + folder_names = [k for k in config_dict.keys() if not k.startswith("_")] + allow_patterns = [os.path.join(k, "*") for k in folder_names] + allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name] + + # download all allow_patterns cached_folder = snapshot_download( pretrained_model_name_or_path, cache_dir=cache_dir, @@ -177,6 +311,7 @@ class DiffusionPipeline(ConfigMixin): local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, + allow_patterns=allow_patterns, ) else: cached_folder = pretrained_model_name_or_path @@ -259,6 +394,8 @@ class DiffusionPipeline(ConfigMixin): loading_kwargs = {} if issubclass(class_obj, torch.nn.Module): loading_kwargs["torch_dtype"] = torch_dtype + if issubclass(class_obj, diffusers.OnnxRuntimeModel): + loading_kwargs["provider"] = provider # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md index 940bcfa685..7957a8c364 100644 --- a/src/diffusers/pipelines/README.md +++ b/src/diffusers/pipelines/README.md @@ -40,8 +40,8 @@ available a colab notebook to directly try them out. | [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | | [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | | [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) -| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) | [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. @@ -70,7 +70,7 @@ not be used for training. If you want to store the gradients during the forward ## Contribution -We are more than happy about any contribution to the offically supported pipelines 🤗. We aspire +We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**. - **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file iteslf, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline. @@ -134,7 +134,7 @@ with autocast("cuda"): images[0].save("fantasy_landscape.png") ``` -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ### Tweak prompts reusing seeds and latents @@ -179,4 +179,4 @@ with autocast("cuda"): images[0].save("cat_on_bench.png") ``` -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb) +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 40ac346749..3e2aeb4fb2 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,4 +1,4 @@ -from ..utils import is_transformers_available +from ..utils import is_onnx_available, is_transformers_available from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline @@ -14,3 +14,6 @@ if is_transformers_available(): StableDiffusionInpaintPipeline, StableDiffusionPipeline, ) + +if is_transformers_available() and is_onnx_available(): + from .stable_diffusion import StableDiffusionOnnxPipeline diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index ac8b032e88..95b49e045f 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -52,21 +52,26 @@ class DDIMPipeline(DiffusionPipeline): ) -> Union[ImagePipelineOutput, Tuple]: r""" Args: - batch_size (:obj:`int`, *optional*, defaults to 1): + batch_size (`int`, *optional*, defaults to 1): The number of images to generate. - generator (:obj:`torch.Generator`, *optional*): + generator (`torch.Generator`, *optional*): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - eta (:obj:`float`, *optional*, defaults to 0.0): + eta (`float`, *optional*, defaults to 0.0): The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM). - num_inference_steps (:obj:`int`, *optional*, defaults to 50): + num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`): + output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`): + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. """ if "torch_device" in kwargs: diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index b2af8a6c4f..b7f7093e37 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -50,16 +50,21 @@ class DDPMPipeline(DiffusionPipeline): ) -> Union[ImagePipelineOutput, Tuple]: r""" Args: - batch_size (:obj:`int`, *optional*, defaults to 1): + batch_size (`int`, *optional*, defaults to 1): The number of images to generate. - generator (:obj:`torch.Generator`, *optional*): + generator (`torch.Generator`, *optional*): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`): + output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`): + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. """ if "torch_device" in kwargs: device = kwargs.pop("torch_device") diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 8d277d9452..8caa11dbdf 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -85,9 +85,14 @@ class LDMTextToImagePipeline(DiffusionPipeline): deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*): Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. """ if "torch_device" in kwargs: device = kwargs.pop("torch_device") @@ -686,7 +691,6 @@ class LDMBertModel(LDMBertPreTrainedModel): output_hidden_states=None, return_dict=None, ): - outputs = self.model( input_ids, attention_mask=attention_mask, diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index ec1e853a88..5574b65df9 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -38,7 +38,6 @@ class LDMPipeline(DiffusionPipeline): return_dict: bool = True, **kwargs, ) -> Union[Tuple, ImagePipelineOutput]: - r""" Args: batch_size (`int`, *optional*, defaults to 1): @@ -51,9 +50,14 @@ class LDMPipeline(DiffusionPipeline): expense of slower inference. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. """ if "torch_device" in kwargs: diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index 9485727854..ae6c10e9e9 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -30,7 +30,7 @@ class PNDMPipeline(DiffusionPipeline): library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Parameters: - unet (:obj:`UNet2DModel`): U-Net architecture to denoise the encoded image latents. + unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image. """ @@ -55,20 +55,22 @@ class PNDMPipeline(DiffusionPipeline): ) -> Union[ImagePipelineOutput, Tuple]: r""" Args: - batch_size (:obj:`int`, `optional`, defaults to 1): The number of images to generate. - num_inference_steps (: - obj:`int`, `optional`, defaults to 50): The number of denoising steps. More denoising steps usually - lead to a higher quality image at the expense of slower inference. - generator (: - obj:`torch.Generator`, `optional`): A [torch + batch_size (`int`, `optional`, defaults to 1): The number of images to generate. + num_inference_steps (`int`, `optional`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator`, `optional`): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - output_type (: - obj:`str`, `optional`, defaults to :obj:`"pil"`): The output format of the generate image. Choose - between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (: - obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not to return a + output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generate image. Choose + between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. """ # For more information on the sampling method you can take a look at Algorithm 2 of # the official paper: https://arxiv.org/pdf/2202.09778.pdf diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 1f5cfe763e..b29795e7f6 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -36,16 +36,21 @@ class ScoreSdeVePipeline(DiffusionPipeline): ) -> Union[ImagePipelineOutput, Tuple]: r""" Args: - batch_size (:obj:`int`, *optional*, defaults to 1): + batch_size (`int`, *optional*, defaults to 1): The number of images to generate. - generator (:obj:`torch.Generator`, *optional*): + generator (`torch.Generator`, *optional*): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`): + output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`): + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. """ if "torch_device" in kwargs: @@ -75,7 +80,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) # correction step - for _ in range(self.scheduler.correct_steps): + for _ in range(self.scheduler.config.correct_steps): model_output = self.unet(sample, sigma_t).sample sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md index 64f17a3f11..0e6cab2e11 100644 --- a/src/diffusers/pipelines/stable_diffusion/README.md +++ b/src/diffusers/pipelines/stable_diffusion/README.md @@ -12,7 +12,7 @@ The summary of the model is the following: - Stable Diffusion has the same architecture as [Latent Diffusion](https://arxiv.org/abs/2112.10752) but uses a frozen CLIP Text Encoder instead of training the text encoder jointly with the diffusion model. - An in-detail explanation of the Stable Diffusion model can be found under [Stable Diffusion with 🧨 Diffusers](https://huggingface.co/blog/stable_diffusion). -- If you don't want to rely on the Hugging Face Hub and having to pass a authentification token, you can +- If you don't want to rely on the Hugging Face Hub and having to pass a authentication token, you can download the weights with `git lfs install; git clone https://huggingface.co/CompVis/stable-diffusion-v1-4` and instead pass the local path to the cloned folder to `from_pretrained` as shown below. - Stable Diffusion can work with a variety of different samplers as is shown below. @@ -21,8 +21,8 @@ download the weights with `git lfs install; git clone https://huggingface.co/Com | Pipeline | Tasks | Colab |---|---|:---:| | [pipeline_stable_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [pipeline_stable_diffusion_img2img](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) -| [pipeline_stable_diffusion_inpaint](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/in_painting_with_stable_diffusion_using_diffusers.ipynb) +| [pipeline_stable_diffusion_img2img](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) +| [pipeline_stable_diffusion_inpaint](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) ## Examples: diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 8bfa394c71..5ffda93f17 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa from dataclasses import dataclass from typing import List, Union @@ -7,7 +6,7 @@ import numpy as np import PIL from PIL import Image -from ...utils import BaseOutput, is_transformers_available +from ...utils import BaseOutput, is_onnx_available, is_transformers_available @dataclass @@ -33,3 +32,6 @@ if is_transformers_available(): from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .safety_checker import StableDiffusionSafetyChecker + +if is_transformers_available() and is_onnx_available(): + from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index ba3e55dfa6..1272fe64e7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -6,6 +6,7 @@ import torch from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -53,6 +54,21 @@ class StableDiffusionPipeline(DiffusionPipeline): ): super().__init__() scheduler = scheduler.set_format("pt") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -136,13 +152,14 @@ class StableDiffusionPipeline(DiffusionPipeline): tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. Returns: - `~pipelines.stable_diffusion.StableDiffusionPipelineOutput` if `return_dict` is True, otherwise a tuple. + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. @@ -198,29 +215,29 @@ class StableDiffusionPipeline(DiffusionPipeline): text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_device = "cpu" if self.device.type == "mps" else self.device latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) if latents is None: latents = torch.randn( latents_shape, generator=generator, - device=self.device, + device=latents_device, dtype=text_embeddings.dtype, ) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) + latents = latents.to(latents_device) # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.timesteps = self.scheduler.timesteps.to(latents_device) - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - self.scheduler.timesteps = torch.tensor(self.scheduler.timesteps, device=self.device) - - # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): latents = latents * self.scheduler.sigmas[0] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 5c63992090..e7adb4d1a3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1,4 +1,5 @@ import inspect +import warnings from typing import List, Optional, Union import numpy as np @@ -7,6 +8,7 @@ import torch import PIL from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -64,6 +66,21 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): ): super().__init__() scheduler = scheduler.set_format("pt") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -99,7 +116,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): back to computing attention in one step. """ # set slice_size = `None` to disable `set_attention_slice` - self.enable_attention_slice(None) + self.enable_attention_slicing(None) @torch.no_grad() def __call__( @@ -146,13 +163,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. Returns: - `~pipelines.stable_diffusion.StableDiffusionPipelineOutput` if `return_dict` is True, otherwise a tuple. + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. @@ -168,16 +186,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - offset = 0 - if accepts_offset: - offset = 1 - extra_set_kwargs["offset"] = 1 + self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - - if not isinstance(init_image, torch.FloatTensor): + if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image) # encode the init image into latents and scale the latents @@ -189,6 +200,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): init_latents = torch.cat([init_latents] * batch_size) # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) if isinstance(self.scheduler, LMSDiscreteScheduler): @@ -248,7 +260,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): sigma = self.scheduler.sigmas[t_index] # the model input needs to be scaled to match the continuous ODE formulation in K-LMS diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 9e6b5c9a9b..b9ad36f1a2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1,4 +1,5 @@ import inspect +import warnings from typing import List, Optional, Union import numpy as np @@ -8,13 +9,18 @@ import PIL from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, PNDMScheduler +from ...utils import logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker +logger = logging.get_logger(__name__) + + def preprocess_image(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 @@ -78,6 +84,22 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): ): super().__init__() scheduler = scheduler.set_format("pt") + logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -113,7 +135,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): back to computing attention in one step. """ # set slice_size = `None` to disable `set_attention_slice` - self.enable_attention_slice(None) + self.enable_attention_slicing(None) @torch.no_grad() def __call__( @@ -140,8 +162,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): process. This is the image whose masked region will be inpainted. mask_image (`torch.FloatTensor` or `PIL.Image.Image`): `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be - replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be - converted to a single channel (luminance) before use. + replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a + PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should + contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` is 1, the denoising process will be run on the masked area for the full number of iterations specified @@ -164,13 +187,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): deterministic. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. Returns: - `~pipelines.stable_diffusion.StableDiffusionPipelineOutput` if `return_dict` is True, otherwise a tuple. + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. @@ -186,20 +210,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - offset = 0 - if accepts_offset: - offset = 1 - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps) # preprocess image - init_image = preprocess_image(init_image).to(self.device) + if not isinstance(init_image, torch.FloatTensor): + init_image = preprocess_image(init_image) + init_image = init_image.to(self.device) # encode the init image into latents and scale the latents - init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample(generator=generator) init_latents = 0.18215 * init_latents @@ -209,14 +228,17 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): init_latents_orig = init_latents # preprocess mask - mask = preprocess_mask(mask_image).to(self.device) - mask = torch.cat([mask] * batch_size) + if not isinstance(mask_image, torch.FloatTensor): + mask_image = preprocess_mask(mask_image) + mask_image = mask_image.to(self.device) + mask = torch.cat([mask_image] * batch_size) # check sizes if not mask.shape == init_latents.shape: raise ValueError("The mask and init_image should be the same size!") # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) timesteps = self.scheduler.timesteps[-init_timestep] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py new file mode 100644 index 0000000000..ccba29ade5 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -0,0 +1,160 @@ +import inspect +from typing import List, Optional, Union + +import numpy as np + +from transformers import CLIPFeatureExtractor, CLIPTokenizer + +from ...onnx_utils import OnnxRuntimeModel +from ...pipeline_utils import DiffusionPipeline +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . import StableDiffusionPipelineOutput + + +class StableDiffusionOnnxPipeline(DiffusionPipeline): + vae_decoder: OnnxRuntimeModel + text_encoder: OnnxRuntimeModel + tokenizer: CLIPTokenizer + unet: OnnxRuntimeModel + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + safety_checker: OnnxRuntimeModel + feature_extractor: CLIPFeatureExtractor + + def __init__( + self, + vae_decoder: OnnxRuntimeModel, + text_encoder: OnnxRuntimeModel, + tokenizer: CLIPTokenizer, + unet: OnnxRuntimeModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: OnnxRuntimeModel, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("np") + self.register_modules( + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + latents: Optional[np.ndarray] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + **kwargs, + ): + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ) + uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = np.concatenate([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, 4, height // 8, width // 8) + if latents is None: + latents = np.random.randn(*latents_shape).astype(np.float32) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + noise_pred = self.unet( + sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae_decoder(latent_sample=latents)[0] + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + # run safety checker + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 5a315c3367..09de92eeb1 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -13,7 +13,7 @@ logger = logging.get_logger(__name__) def cosine_distance(image_embeds, text_embeds): normalized_image_embeds = nn.functional.normalize(image_embeds) normalized_text_embeds = nn.functional.normalize(text_embeds) - return torch.mm(normalized_image_embeds, normalized_text_embeds.T) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) class StableDiffusionSafetyChecker(PreTrainedModel): @@ -78,3 +78,29 @@ class StableDiffusionSafetyChecker(PreTrainedModel): ) return images, has_nsfw_concepts + + @torch.inference_mode() + def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) + + images[has_nsfw_concepts] = 0.0 # black image + + return images, has_nsfw_concepts diff --git a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py index 29b6032de9..1984a25ac0 100644 --- a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -45,19 +45,24 @@ class KarrasVePipeline(DiffusionPipeline): ) -> Union[Tuple, ImagePipelineOutput]: r""" Args: - batch_size (:obj:`int`, *optional*, defaults to 1): + batch_size (`int`, *optional*, defaults to 1): The number of images to generate. - generator (:obj:`torch.Generator`, *optional*): + generator (`torch.Generator`, *optional*): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - num_inference_steps (:obj:`int`, *optional*, defaults to 50): + num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - output_type (:obj:`str`, *optional*, defaults to :obj:`"pil"`): + output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. - return_dict (:obj:`bool`, *optional*, defaults to :obj:`True`): + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. + + Returns: + [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. """ if "torch_device" in kwargs: device = kwargs.pop("torch_device") diff --git a/src/diffusers/schedulers/README.md b/src/diffusers/schedulers/README.md index 3b1eb9342e..edf2299446 100644 --- a/src/diffusers/schedulers/README.md +++ b/src/diffusers/schedulers/README.md @@ -1,7 +1,7 @@ # Schedulers - Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps. -- Schedulers can be used interchangable between diffusion models in inference to find the preferred trade-off between speed and generation quality. +- Schedulers can be used interchangeable between diffusion models in inference to find the preferred trade-off between speed and generation quality. - Schedulers are available in numpy, but can easily be transformed into PyTorch. ## API diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 20c25f3518..495f30d9fa 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -12,17 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..utils import is_scipy_available -from .scheduling_ddim import DDIMScheduler -from .scheduling_ddpm import DDPMScheduler -from .scheduling_karras_ve import KarrasVeScheduler -from .scheduling_pndm import PNDMScheduler -from .scheduling_sde_ve import ScoreSdeVeScheduler -from .scheduling_sde_vp import ScoreSdeVpScheduler -from .scheduling_utils import SchedulerMixin +from ..utils import is_flax_available, is_scipy_available, is_torch_available + + +if is_torch_available(): + from .scheduling_ddim import DDIMScheduler + from .scheduling_ddpm import DDPMScheduler + from .scheduling_karras_ve import KarrasVeScheduler + from .scheduling_pndm import PNDMScheduler + from .scheduling_sde_ve import ScoreSdeVeScheduler + from .scheduling_sde_vp import ScoreSdeVpScheduler + from .scheduling_utils import SchedulerMixin +else: + from ..utils.dummy_pt_objects import * # noqa F403 + +if is_flax_available(): + from .scheduling_ddim_flax import FlaxDDIMScheduler + from .scheduling_ddpm_flax import FlaxDDPMScheduler + from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler + from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler + from .scheduling_pndm_flax import FlaxPNDMScheduler + from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler +else: + from ..utils.dummy_flax_objects import * # noqa F403 if is_scipy_available(): from .scheduling_lms_discrete import LMSDiscreteScheduler else: - from ..utils.dummy_scipy_objects import * # noqa F403 + from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index d15c55410c..32be871f0b 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -16,6 +16,7 @@ # and https://github.com/hojonathanho/diffusion import math +import warnings from typing import Optional, Tuple, Union import numpy as np @@ -59,6 +60,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + For more details, see the original paper: https://arxiv.org/abs/2010.02502 Args: @@ -68,12 +74,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): TODO - timestep_values (`np.ndarray`, optional): TODO + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. set_alpha_to_one (`bool`, default `True`): - if alpha for final step is 1 or the final alpha of the "non-previous" one. + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. """ @@ -86,9 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - timestep_values: Optional[np.ndarray] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, + steps_offset: int = 0, tensor_format: str = "pt", ): if trained_betas is not None: @@ -109,7 +121,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this paratemer simply to one or + # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] @@ -130,19 +142,31 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): return variance - def set_timesteps(self, num_inference_steps: int, offset: int = 0): + def set_timesteps(self, num_inference_steps: int, **kwargs): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): TODO """ + + offset = self.config.steps_offset + + if "offset" in kwargs: + warnings.warn( + "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." + " Please pass `steps_offset` to `__init__` instead.", + DeprecationWarning, + ) + + offset = kwargs["offset"] + self.num_inference_steps = num_inference_steps - self.timesteps = np.arange( - 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps - )[::-1].copy() + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() self.timesteps += offset self.set_format(tensor_format=self.tensor_format) @@ -171,7 +195,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - `SchedulerOutput`: updated sample in the diffusion chain. + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: @@ -187,7 +213,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): # - pred_original_sample -> f_theta(x_t, t) or x_0 # - std_dev_t -> sigma_t # - eta -> η - # - pred_sample_direction -> "direction pointingc to x_t" + # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) @@ -242,6 +268,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: + if self.tensor_format == "pt": + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py new file mode 100644 index 0000000000..dd3c2ac85d --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -0,0 +1,274 @@ +# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=jnp.float32) + + +@flax.struct.dataclass +class DDIMSchedulerState: + # setable values + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + + @classmethod + def create(cls, num_train_timesteps: int): + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + + +@dataclass +class FlaxSchedulerOutput(SchedulerOutput): + state: DDIMSchedulerState + + +class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + set_alpha_to_one (`bool`, default `True`): + if alpha for final step is 1 or the final alpha of the "non-previous" one. + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + ): + if trained_betas is not None: + self.betas = jnp.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def set_timesteps( + self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0 + ) -> DDIMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DDIMSchedulerState`): + the `FlaxDDIMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + offset (`int`): + optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. + """ + step_ratio = self.config.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + timesteps = timesteps + offset + + return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps) + + def step( + self, + state: DDIMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + key: random.KeyArray, + eta: float = 0.0, + use_clipped_model_output: bool = False, + return_dict: bool = True, + ) -> Union[FlaxSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + key (`random.KeyArray`): a PRNG key. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + + # 4. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = jnp.clip(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + key = random.split(key, num=1) + noise = random.normal(key=key, shape=model_output.shape) + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample, state) + + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + + def add_noise( + self, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index bdd87f508e..fac75bc43e 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -58,6 +58,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and Langevin dynamics sampling. + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + For more details, see the original paper: https://arxiv.org/abs/2006.11239 Args: @@ -67,7 +72,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): TODO + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. variance_type (`str`): options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. @@ -89,7 +95,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): clip_sample: bool = True, tensor_format: str = "pt", ): - if trained_betas is not None: self.betas = np.asarray(trained_betas) elif beta_schedule == "linear": @@ -143,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): if variance_type is None: variance_type = self.config.variance_type - # hacks - were probs added for training stability + # hacks - were probably added for training stability if variance_type == "fixed_small": variance = self.clip(variance, min_value=1e-20) # for rl-diffuser https://arxiv.org/abs/2205.09991 @@ -182,14 +187,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): timestep (`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`): current instance of sample being created by diffusion process. - eta (`float`): weight of noise for added noise in diffusion step. predict_epsilon (`bool`): optional flag to use when model predicts the samples directly instead of the noise, epsilon. generator: random number generator. return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - `SchedulerOutput`: updated sample in the diffusion chain. + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. """ t = timestep @@ -244,7 +250,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: - + if self.tensor_format == "pt": + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py new file mode 100644 index 0000000000..f686a2a322 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -0,0 +1,277 @@ +# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=jnp.float32) + + +@flax.struct.dataclass +class DDPMSchedulerState: + # setable values + timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + + @classmethod + def create(cls, num_train_timesteps: int): + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + + +@dataclass +class FlaxSchedulerOutput(SchedulerOutput): + state: DDPMSchedulerState + + +class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and + Langevin dynamics sampling. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2006.11239 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + variance_type (`str`): + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + ): + if trained_betas is not None: + self.betas = jnp.asarray(trained_betas) + elif beta_schedule == "linear": + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + self.one = jnp.array(1.0) + + self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps) + + self.variance_type = variance_type + + def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`DDIMSchedulerState`): + the `FlaxDDPMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + timesteps = jnp.arange( + 0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps + )[::-1] + return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps) + + def _get_variance(self, t, predicted_variance=None, variance_type=None): + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # and sample from it to get previous sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + + if variance_type is None: + variance_type = self.config.variance_type + + # hacks - were probably added for training stability + if variance_type == "fixed_small": + variance = jnp.clip(variance, a_min=1e-20) + # for rl-diffuser https://arxiv.org/abs/2205.09991 + elif variance_type == "fixed_small_log": + variance = jnp.log(jnp.clip(variance, a_min=1e-20)) + elif variance_type == "fixed_large": + variance = self.betas[t] + elif variance_type == "fixed_large_log": + # Glide max_log + variance = jnp.log(self.betas[t]) + elif variance_type == "learned": + return predicted_variance + elif variance_type == "learned_range": + min_log = variance + max_log = self.betas[t] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + + def step( + self, + state: DDPMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + key: random.KeyArray, + predict_epsilon: bool = True, + return_dict: bool = True, + ) -> Union[FlaxSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`DDPMSchedulerState`): the `FlaxDDPMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + key (`random.KeyArray`): a PRNG key. + predict_epsilon (`bool`): + optional flag to use when model predicts the samples directly instead of the noise, epsilon. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + t = timestep + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: + model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if predict_epsilon: + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + else: + pred_original_sample = model_output + + # 3. Clip "predicted x_0" + if self.config.clip_sample: + pred_original_sample = jnp.clip(pred_original_sample, -1, 1) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t + current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample + + # 6. Add noise + variance = 0 + if t > 0: + key = random.split(key, num=1) + noise = random.normal(key=key, shape=model_output.shape) + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample, state) + + return FlaxSchedulerOutput(prev_sample=pred_prev_sample, state=state) + + def add_noise( + self, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 0352be6e3e..caf7625fb6 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -34,7 +34,7 @@ class KarrasVeOutput(BaseOutput): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Derivate of predicted original image sample (x_0). + Derivative of predicted original image sample (x_0). """ prev_sample: torch.FloatTensor @@ -50,6 +50,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456 + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. @@ -100,7 +105,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): self.num_inference_steps = num_inference_steps self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() self.schedule = [ - (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1))) + ( + self.config.sigma_max + * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) + ) for i in self.timesteps ] self.schedule = np.array(self.schedule, dtype=np.float32) @@ -116,13 +124,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): TODO Args: """ - if self.s_min <= sigma <= self.s_max: - gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1) + if self.config.s_min <= sigma <= self.config.s_max: + gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1) else: gamma = 0 # sample eps ~ N(0, S_noise^2 * I) - eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) + eps = self.config.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device) sigma_hat = sigma + gamma * sigma sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) @@ -147,8 +155,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - Returns: KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). + Returns: + [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`: + [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. """ diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py new file mode 100644 index 0000000000..fe71b3fd0e --- /dev/null +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -0,0 +1,228 @@ +# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@flax.struct.dataclass +class KarrasVeSchedulerState: + # setable values + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + schedule: Optional[jnp.ndarray] = None # sigma(t_i) + + @classmethod + def create(cls): + return cls() + + +@dataclass +class FlaxKarrasVeOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Derivate of predicted original image sample (x_0). + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + """ + + prev_sample: jnp.ndarray + derivative: jnp.ndarray + state: KarrasVeSchedulerState + + +class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin): + """ + Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and + the VE column of Table 1 from [1] for reference. + + [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." + https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic + differential equations." https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of + Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the + optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper. + + Args: + sigma_min (`float`): minimum noise magnitude + sigma_max (`float`): maximum noise magnitude + s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling. + A reasonable range is [1.000, 1.011]. + s_churn (`float`): the parameter controlling the overall amount of stochasticity. + A reasonable range is [0, 100]. + s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity). + A reasonable range is [0, 10]. + s_max (`float`): the end value of the sigma range where we add noise. + A reasonable range is [0.2, 80]. + """ + + @register_to_config + def __init__( + self, + sigma_min: float = 0.02, + sigma_max: float = 100, + s_noise: float = 1.007, + s_churn: float = 80, + s_min: float = 0.05, + s_max: float = 50, + ): + self.state = KarrasVeSchedulerState.create() + + def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState: + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`KarrasVeSchedulerState`): + the `FlaxKarrasVeScheduler` state data class. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + + """ + timesteps = jnp.arange(0, num_inference_steps)[::-1].copy() + schedule = [ + ( + self.config.sigma_max + * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) + ) + for i in timesteps + ] + + return state.replace( + num_inference_steps=num_inference_steps, + schedule=jnp.array(schedule, dtype=jnp.float32), + timesteps=timesteps, + ) + + def add_noise_to_input( + self, + state: KarrasVeSchedulerState, + sample: jnp.ndarray, + sigma: float, + key: random.KeyArray, + ) -> Tuple[jnp.ndarray, float]: + """ + Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a + higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. + + TODO Args: + """ + if self.config.s_min <= sigma <= self.config.s_max: + gamma = min(self.config.s_churn / state.num_inference_steps, 2**0.5 - 1) + else: + gamma = 0 + + # sample eps ~ N(0, S_noise^2 * I) + key = random.split(key, num=1) + eps = self.config.s_noise * random.normal(key=key, shape=sample.shape) + sigma_hat = sigma + gamma * sigma + sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps) + + return sample_hat, sigma_hat + + def step( + self, + state: KarrasVeSchedulerState, + model_output: jnp.ndarray, + sigma_hat: float, + sigma_prev: float, + sample_hat: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxKarrasVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] or `tuple`: Updated sample in the diffusion + chain and derivative. [`~schedulers.scheduling_karras_ve_flax.FlaxKarrasVeOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + """ + + pred_original_sample = sample_hat + sigma_hat * model_output + derivative = (sample_hat - pred_original_sample) / sigma_hat + sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative + + if not return_dict: + return (sample_prev, derivative, state) + + return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) + + def step_correct( + self, + state: KarrasVeSchedulerState, + model_output: jnp.ndarray, + sigma_hat: float, + sigma_prev: float, + sample_hat: jnp.ndarray, + sample_prev: jnp.ndarray, + derivative: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxKarrasVeOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. TODO complete description + + Args: + state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + sigma_hat (`float`): TODO + sigma_prev (`float`): TODO + sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO + sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO + derivative (`torch.FloatTensor` or `np.ndarray`): TODO + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO + + """ + pred_original_sample = sample_prev + sigma_prev * model_output + derivative_corr = (sample_prev - pred_original_sample) / sigma_prev + sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) + + if not return_dict: + return (sample_prev, derivative, state) + + return FlaxKarrasVeOutput(prev_sample=sample_prev, derivative=derivative, state=state) + + def add_noise(self, original_samples, noise, timesteps): + raise NotImplementedError() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 31d482ae59..5857ae70a8 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -29,6 +29,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): Katherine Crowson: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the starting `beta` value of inference. @@ -36,10 +41,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear` or `scaled_linear`. - trained_betas (`np.ndarray`, optional): TODO + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. - timestep_values (`np.ndarry`, optional): TODO tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. """ @@ -52,7 +57,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - timestep_values: Optional[np.ndarray] = None, tensor_format: str = "pt", ): if trained_betas is not None: @@ -109,14 +113,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): the number of diffusion steps used when generating samples with a pre-trained model. """ self.num_inference_steps = num_inference_steps - self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) low_idx = np.floor(self.timesteps).astype(int) high_idx = np.ceil(self.timesteps).astype(int) frac = np.mod(self.timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]) + self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.derivatives = [] @@ -143,7 +147,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain. + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. """ sigma = self.sigmas[timestep] @@ -177,6 +183,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: + if self.tensor_format == "pt": + timesteps = timesteps.to(self.sigmas.device) sigmas = self.match_shape(self.sigmas[timesteps], noise) noisy_samples = original_samples + noise * sigmas diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py new file mode 100644 index 0000000000..1431bdacf5 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -0,0 +1,207 @@ +# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +@flax.struct.dataclass +class LMSDiscreteSchedulerState: + # setable values + num_inference_steps: Optional[int] = None + timesteps: Optional[jnp.ndarray] = None + sigmas: Optional[jnp.ndarray] = None + derivatives: jnp.ndarray = jnp.array([]) + + @classmethod + def create(cls, num_train_timesteps: int, sigmas: jnp.ndarray): + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], sigmas=sigmas) + + +@dataclass +class FlaxSchedulerOutput(SchedulerOutput): + state: LMSDiscreteSchedulerState + + +class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + ): + if trained_betas is not None: + self.betas = jnp.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + + self.state = LMSDiscreteSchedulerState.create( + num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + ) + + def get_lms_coefficient(self, state, order, t, current_order): + """ + Compute a linear multistep coefficient. + + Args: + order (TODO): + t (TODO): + current_order (TODO): + """ + + def lms_derivative(tau): + prod = 1.0 + for k in range(order): + if current_order == k: + continue + prod *= (tau - state.sigmas[t - k]) / (state.sigmas[t - current_order] - state.sigmas[t - k]) + return prod + + integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0] + + return integrated_coeff + + def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState: + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`LMSDiscreteSchedulerState`): + the `FlaxLMSDiscreteScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=jnp.float32) + + low_idx = jnp.floor(timesteps).astype(int) + high_idx = jnp.ceil(timesteps).astype(int) + frac = jnp.mod(timesteps, 1.0) + sigmas = jnp.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + sigmas = jnp.concatenate([sigmas, jnp.array([0.0])]).astype(jnp.float32) + + return state.replace( + num_inference_steps=num_inference_steps, + timesteps=timesteps, + derivatives=jnp.array([]), + sigmas=sigmas, + ) + + def step( + self, + state: LMSDiscreteSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + order: int = 4, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`LMSDiscreteSchedulerState`): the `FlaxLMSDiscreteScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + order: coefficient for multi-step inference. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + sigma = state.sigmas[timestep] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + state = state.replace(derivatives=state.derivatives.append(derivative)) + if len(state.derivatives) > order: + state = state.replace(derivatives=state.derivatives.pop(0)) + + # 3. Compute linear multistep coefficients + order = min(timestep + 1, order) + lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)] + + # 4. Compute previous sample based on the derivatives path + prev_sample = sample + sum( + coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(state.derivatives)) + ) + + if not return_dict: + return (prev_sample, state) + + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + + def add_noise( + self, + state: LMSDiscreteSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sigmas = self.match_shape(state.sigmas[timesteps], noise) + noisy_samples = original_samples + noise * sigmas + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 171b509898..09e8a7e240 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -15,6 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math +import warnings from typing import Optional, Tuple, Union import numpy as np @@ -58,6 +59,11 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, namely Runge-Kutta method and a linear multi-step method. + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + For more details, see the original paper: https://arxiv.org/abs/2202.09778 Args: @@ -67,11 +73,20 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): TODO - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `False`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays """ @@ -83,8 +98,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 0, + tensor_format: str = "pt", ): if trained_betas is not None: self.betas = np.asarray(trained_betas) @@ -102,7 +119,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): self.alphas = 1.0 - self.betas self.alphas_cumprod = np.cumprod(self.alphas, axis=0) - self.one = np.array(1.0) + self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf @@ -118,7 +135,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): # setable values self.num_inference_steps = None self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - self._offset = 0 self.prk_timesteps = None self.plms_timesteps = None self.timesteps = None @@ -126,21 +142,31 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor: + def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): TODO """ + + offset = self.config.steps_offset + + if "offset" in kwargs: + warnings.warn( + "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." + " Please pass `steps_offset` to `__init__` instead." + ) + + offset = kwargs["offset"] + self.num_inference_steps = num_inference_steps - self._timesteps = list( - range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) - ) - self._offset = offset - self._timesteps = np.array([t + self._offset for t in self._timesteps]) + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() + self._timesteps += offset if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to @@ -186,7 +212,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - `SchedulerOutput`: updated sample in the diffusion chain. + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. """ if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: @@ -213,7 +241,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain. + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: @@ -222,7 +251,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ) diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 - prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) + prev_timestep = timestep - diff_to_prev timestep = self.prk_timesteps[self.counter // 4 * 4] if self.counter % 4 == 0: @@ -267,7 +296,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain. + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if self.num_inference_steps is None: @@ -283,7 +313,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): "for more information." ) - prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps if self.counter != 1: self.ets.append(model_output) @@ -313,7 +343,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return SchedulerOutput(prev_sample=prev_sample) - def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): + def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # this function computes x_(t−δ) using the formula of (9) # Note that x_t needs to be added to both sides of the equation @@ -326,8 +356,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): # sample -> x_t # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) - alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset] - alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -355,7 +385,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> torch.Tensor: - + if self.tensor_format == "pt": + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py new file mode 100644 index 0000000000..8444d66804 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -0,0 +1,406 @@ +# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=jnp.float32) + + +@flax.struct.dataclass +class PNDMSchedulerState: + # setable values + _timesteps: jnp.ndarray + num_inference_steps: Optional[int] = None + _offset: int = 0 + prk_timesteps: Optional[jnp.ndarray] = None + plms_timesteps: Optional[jnp.ndarray] = None + timesteps: Optional[jnp.ndarray] = None + + # running values + cur_model_output: Optional[jnp.ndarray] = None + counter: int = 0 + cur_sample: Optional[jnp.ndarray] = None + ets: jnp.ndarray = jnp.array([]) + + @classmethod + def create(cls, num_train_timesteps: int): + return cls(_timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + + +@dataclass +class FlaxSchedulerOutput(SchedulerOutput): + state: PNDMSchedulerState + + +class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`jnp.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + skip_prk_steps (`bool`): + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms steps; defaults to `False`. + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.ndarray] = None, + skip_prk_steps: bool = False, + ): + if trained_betas is not None: + self.betas = jnp.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps) + + def set_timesteps( + self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0 + ) -> PNDMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`PNDMSchedulerState`): + the `FlaxPNDMScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + offset (`int`): + optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. + """ + step_ratio = self.config.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # rounding to avoid issues when num_inference_step is power of 3 + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + _timesteps = _timesteps + offset + + state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps) + + if self.config.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + state = state.replace( + prk_timesteps=jnp.array([]), + plms_timesteps=jnp.concatenate( + [state._timesteps[:-1], state._timesteps[-2:-1], state._timesteps[-1:]] + )[::-1], + ) + else: + prk_timesteps = jnp.array(state._timesteps[-self.pndm_order :]).repeat(2) + jnp.tile( + jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + + state = state.replace( + prk_timesteps=(prk_timesteps[:-1].repeat(2)[1:-1])[::-1], + plms_timesteps=state._timesteps[:-3][::-1], + ) + + return state.replace( + timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64), + ets=jnp.array([]), + counter=0, + ) + + def step( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps: + return self.step_prk( + state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + ) + else: + return self.step_plms( + state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + ) + + def step_prk( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxSchedulerOutput, Tuple]: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2 + prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1]) + timestep = state.prk_timesteps[state.counter // 4 * 4] + + if state.counter % 4 == 0: + state = state.replace( + cur_model_output=state.cur_model_output + 1 / 6 * model_output, + ets=state.ets.append(model_output), + cur_sample=sample, + ) + elif (self.counter - 1) % 4 == 0: + state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) + elif (self.counter - 2) % 4 == 0: + state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) + elif (self.counter - 3) % 4 == 0: + model_output = state.cur_model_output + 1 / 6 * model_output + state = state.replace(cur_model_output=0) + + # cur_sample should not be `None` + cur_sample = state.cur_sample if state.cur_sample is not None else sample + + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state) + state = state.replace(counter=state.counter + 1) + + if not return_dict: + return (prev_sample, state) + + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + + def step_plms( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[FlaxSchedulerOutput, Tuple]: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.config.skip_prk_steps and len(state.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " + "for more information." + ) + + prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0) + + if state.counter != 1: + state = state.replace(ets=state.ets.append(model_output)) + else: + prev_timestep = timestep + timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps + + if len(state.ets) == 1 and state.counter == 0: + model_output = model_output + state = state.replace(cur_sample=sample) + elif len(state.ets) == 1 and state.counter == 1: + model_output = (model_output + state.ets[-1]) / 2 + sample = state.cur_sample + state = state.replace(cur_sample=None) + elif len(state.ets) == 2: + model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 + elif len(state.ets) == 3: + model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 + else: + model_output = (1 / 24) * ( + 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4] + ) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state) + state = state.replace(counter=state.counter + 1) + + if not return_dict: + return (prev_sample, state) + + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + + def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset] + alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset] + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample + + def add_noise( + self, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 308f42c91f..4af8f4fdad 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -49,14 +49,20 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): For more information, see the original paper: https://arxiv.org/abs/2011.13456 + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. snr (`float`): coefficient weighting the step from the model_output sample (from the network) to the random noise. sigma_min (`float`): initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the distribution of the data. sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. - sampling_eps (`float`): the end value of sampling, where timesteps decrease progessively from 1 to + sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to epsilon. correct_steps (`int`): number of correction steps performed on a produced sample. tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler. @@ -139,7 +145,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) elif tensor_format == "pt": return torch.where( - timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device) + timesteps == 0, + torch.zeros_like(t.to(timesteps.device)), + self.discrete_sigmas[timesteps - 1].to(timesteps.device), ) raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") @@ -180,7 +188,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain. + [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if "seed" in kwargs and kwargs["seed"] is not None: @@ -196,8 +205,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ) # torch.repeat_interleave(timestep, sample.shape[0]) timesteps = (timestep * (len(self.timesteps) - 1)).long() + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.discrete_sigmas.device) + sigma = self.discrete_sigmas[timesteps].to(sample.device) - adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep) + adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) drift = self.zeros_like(sample) diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 @@ -236,7 +248,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain. + [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ if "seed" in kwargs and kwargs["seed"] is not None: diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py new file mode 100644 index 0000000000..e5860706aa --- /dev/null +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -0,0 +1,260 @@ +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import flax +import jax.numpy as jnp +from jax import random + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +@flax.struct.dataclass +class ScoreSdeVeSchedulerState: + # setable values + timesteps: Optional[jnp.ndarray] = None + discrete_sigmas: Optional[jnp.ndarray] = None + sigmas: Optional[jnp.ndarray] = None + + @classmethod + def create(cls): + return cls() + + +@dataclass +class FlaxSdeVeOutput(SchedulerOutput): + """ + Output class for the ScoreSdeVeScheduler's step function output. + + Args: + state (`ScoreSdeVeSchedulerState`): + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + prev_sample_mean (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. + """ + + state: ScoreSdeVeSchedulerState + prev_sample: jnp.ndarray + prev_sample_mean: Optional[jnp.ndarray] = None + + +class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): + """ + The variance exploding stochastic differential equation (SDE) scheduler. + + For more information, see the original paper: https://arxiv.org/abs/2011.13456 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + snr (`float`): + coefficient weighting the step from the model_output sample (from the network) to the random noise. + sigma_min (`float`): + initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the + distribution of the data. + sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model. + sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to + epsilon. + correct_steps (`int`): number of correction steps performed on a produced sample. + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 2000, + snr: float = 0.15, + sigma_min: float = 0.01, + sigma_max: float = 1348.0, + sampling_eps: float = 1e-5, + correct_steps: int = 1, + ): + state = ScoreSdeVeSchedulerState.create() + + self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps) + + def set_timesteps( + self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None + ) -> ScoreSdeVeSchedulerState: + """ + Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). + + """ + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + + timesteps = jnp.linspace(1, sampling_eps, num_inference_steps) + return state.replace(timesteps=timesteps) + + def set_sigmas( + self, + state: ScoreSdeVeSchedulerState, + num_inference_steps: int, + sigma_min: float = None, + sigma_max: float = None, + sampling_eps: float = None, + ) -> ScoreSdeVeSchedulerState: + """ + Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. + + The sigmas control the weight of the `drift` and `diffusion` components of sample update. + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + sigma_min (`float`, optional): + initial noise scale value (overrides value given at Scheduler instantiation). + sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). + sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). + """ + sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min + sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps + if state.timesteps is None: + state = self.set_timesteps(state, num_inference_steps, sampling_eps) + + discrete_sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_min), jnp.log(sigma_max), num_inference_steps)) + sigmas = jnp.array([sigma_min * (sigma_max / sigma_min) ** t for t in state.timesteps]) + + return state.replace(discrete_sigmas=discrete_sigmas, sigmas=sigmas) + + def get_adjacent_sigma(self, state, timesteps, t): + return jnp.where(timesteps == 0, jnp.zeros_like(t), state.discrete_sigmas[timesteps - 1]) + + def step_pred( + self, + state: ScoreSdeVeSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + key: random.KeyArray, + return_dict: bool = True, + ) -> Union[FlaxSdeVeOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if state.timesteps is None: + raise ValueError( + "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + timestep = timestep * jnp.ones( + sample.shape[0], + ) + timesteps = (timestep * (len(state.timesteps) - 1)).long() + + sigma = state.discrete_sigmas[timesteps] + adjacent_sigma = self.get_adjacent_sigma(state, timesteps, timestep) + drift = jnp.zeros_like(sample) + diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 + + # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) + # also equation 47 shows the analog from SDE models to ancestral sampling methods + drift = drift - diffusion[:, None, None, None] ** 2 * model_output + + # equation 6: sample noise for the diffusion term of + key = random.split(key, num=1) + noise = random.normal(key=key, shape=sample.shape) + prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep + # TODO is the variable diffusion the correct scaling term for the noise? + prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g + + if not return_dict: + return (prev_sample, prev_sample_mean, state) + + return FlaxSdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean, state=state) + + def step_correct( + self, + state: ScoreSdeVeSchedulerState, + model_output: jnp.ndarray, + sample: jnp.ndarray, + key: random.KeyArray, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Correct the predicted sample based on the output model_output of the network. This is often run repeatedly + after making the prediction for the previous timestep. + + Args: + state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + generator: random number generator. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`FlaxSdeVeOutput`] or `tuple`: [`FlaxSdeVeOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if state.timesteps is None: + raise ValueError( + "`state.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler" + ) + + # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z" + # sample noise for correction + key = random.split(key, num=1) + noise = random.normal(key=key, shape=sample.shape) + + # compute step size from the model_output, the noise, and the snr + grad_norm = jnp.linalg.norm(model_output) + noise_norm = jnp.linalg.norm(noise) + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = step_size * jnp.ones(sample.shape[0]) + + # compute corrected sample: model_output term and noise term + prev_sample_mean = sample + step_size[:, None, None, None] * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise + + if not return_dict: + return (prev_sample, state) + + return FlaxSdeVeOutput(prev_sample=prev_sample, state=state) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index d3482f4b00..f19a5ad76f 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -14,7 +14,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch -# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit +# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit import numpy as np import torch @@ -27,6 +27,11 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): """ The variance preserving stochastic differential equation (SDE) scheduler. + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + For more information, see the original paper: https://arxiv.org/abs/2011.13456 UNDER CONSTRUCTION @@ -35,7 +40,6 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): - self.sigmas = None self.discrete_sigmas = None self.timesteps = None diff --git a/src/diffusers/testing_utils.py b/src/diffusers/testing_utils.py index 13f6332a94..7daf2bc633 100644 --- a/src/diffusers/testing_utils.py +++ b/src/diffusers/testing_utils.py @@ -2,12 +2,22 @@ import os import random import unittest from distutils.util import strtobool +from typing import Union import torch +import PIL.Image +import PIL.ImageOps +import requests +from packaging import version + global_rng = random.Random() torch_device = "cuda" if torch.cuda.is_available() else "cpu" +is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12") + +if is_torch_higher_equal_than_1_12: + torch_device = "mps" if torch.backends.mps.is_available() else torch_device def parse_flag_from_env(key, default=False): @@ -53,3 +63,32 @@ def slow(test_case): """ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + + +def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: + """ + Args: + Loads `image` to a PIL Image. + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + Returns: + `PIL.Image.Image`: A PIL Image. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = PIL.Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = PIL.Image.open(image) + else: + raise ValueError( + f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path" + ) + elif isinstance(image, PIL.Image.Image): + image = image + else: + raise ValueError( + "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image." + ) + image = PIL.ImageOps.exif_transpose(image) + image = image.convert("RGB") + return image diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index f9172e8dc9..c00a28e105 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -25,6 +25,7 @@ from .import_utils import ( is_flax_available, is_inflect_available, is_modelcards_available, + is_onnx_available, is_scipy_available, is_tf_available, is_torch_available, diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py new file mode 100644 index 0000000000..9615afb6f9 --- /dev/null +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -0,0 +1,60 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class FlaxModelMixin(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxDDIMScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxDDPMScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxKarrasVeScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxLMSDiscreteScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxPNDMScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxUNet2DConditionModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxScoreSdeVeScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py new file mode 100644 index 0000000000..531c0b7766 --- /dev/null +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -0,0 +1,165 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class ModelMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class AutoencoderKL(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class UNet2DConditionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class UNet2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class VQModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +def get_constant_schedule(*args, **kwargs): + requires_backends(get_constant_schedule, ["torch"]) + + +def get_constant_schedule_with_warmup(*args, **kwargs): + requires_backends(get_constant_schedule_with_warmup, ["torch"]) + + +def get_cosine_schedule_with_warmup(*args, **kwargs): + requires_backends(get_cosine_schedule_with_warmup, ["torch"]) + + +def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): + requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"]) + + +def get_linear_schedule_with_warmup(*args, **kwargs): + requires_backends(get_linear_schedule_with_warmup, ["torch"]) + + +def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): + requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch"]) + + +def get_scheduler(*args, **kwargs): + requires_backends(get_scheduler, ["torch"]) + + +class DiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DDIMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DDPMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class KarrasVePipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class LDMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PNDMPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ScoreSdeVePipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DDIMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class DDPMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class KarrasVeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class PNDMScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class SchedulerMixin(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class ScoreSdeVeScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class EMAModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) diff --git a/src/diffusers/utils/dummy_scipy_objects.py b/src/diffusers/utils/dummy_torch_and_scipy_objects.py similarity index 73% rename from src/diffusers/utils/dummy_scipy_objects.py rename to src/diffusers/utils/dummy_torch_and_scipy_objects.py index 3706c57541..49c8956483 100644 --- a/src/diffusers/utils/dummy_scipy_objects.py +++ b/src/diffusers/utils/dummy_torch_and_scipy_objects.py @@ -5,7 +5,7 @@ from ..utils import DummyObject, requires_backends class LMSDiscreteScheduler(metaclass=DummyObject): - _backends = ["scipy"] + _backends = ["torch", "scipy"] def __init__(self, *args, **kwargs): - requires_backends(self, ["scipy"]) + requires_backends(self, ["torch", "scipy"]) diff --git a/src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py similarity index 51% rename from src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py rename to src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py index 8c2aec218c..967e231d87 100644 --- a/src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_onnx_objects.py @@ -1,10 +1,11 @@ # This file is autogenerated by the command `make fix-copies`, do not edit. # flake8: noqa + from ..utils import DummyObject, requires_backends -class GradTTSPipeline(metaclass=DummyObject): - _backends = ["transformers", "inflect", "unidecode"] +class StableDiffusionOnnxPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "onnx"] def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers", "inflect", "unidecode"]) + requires_backends(self, ["torch", "transformers", "onnx"]) diff --git a/src/diffusers/utils/dummy_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py similarity index 57% rename from src/diffusers/utils/dummy_transformers_objects.py rename to src/diffusers/utils/dummy_torch_and_transformers_objects.py index e05eb814d1..6e4ab48c33 100644 --- a/src/diffusers/utils/dummy_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -5,28 +5,28 @@ from ..utils import DummyObject, requires_backends class LDMTextToImagePipeline(metaclass=DummyObject): - _backends = ["transformers"] + _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) + requires_backends(self, ["torch", "transformers"]) class StableDiffusionImg2ImgPipeline(metaclass=DummyObject): - _backends = ["transformers"] + _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) + requires_backends(self, ["torch", "transformers"]) class StableDiffusionInpaintPipeline(metaclass=DummyObject): - _backends = ["transformers"] + _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) + requires_backends(self, ["torch", "transformers"]) class StableDiffusionPipeline(metaclass=DummyObject): - _backends = ["transformers"] + _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): - requires_backends(self, ["transformers"]) + requires_backends(self, ["torch", "transformers"]) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 05068b6deb..de344d074d 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -136,6 +136,22 @@ except importlib_metadata.PackageNotFoundError: _modelcards_available = False +_onnx_available = importlib.util.find_spec("onnxruntime") is not None +if _onnx_available: + candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino") + _onnxruntime_version = None + # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu + for pkg in candidates: + try: + _onnxruntime_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _onnx_available = _onnxruntime_version is not None + if _onnx_available: + logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}") + + _scipy_available = importlib.util.find_spec("scipy") is not None try: _scipy_version = importlib_metadata.version("scipy") @@ -172,6 +188,10 @@ def is_modelcards_available(): return _modelcards_available +def is_onnx_available(): + return _onnx_available + + def is_scipy_available(): return _scipy_available @@ -194,6 +214,12 @@ PYTORCH_IMPORT_ERROR = """ installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. """ +# docstyle-ignore +ONNX_IMPORT_ERROR = """ +{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip +install onnxruntime` +""" + # docstyle-ignore SCIPY_IMPORT_ERROR = """ {0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install @@ -223,6 +249,7 @@ BACKENDS_MAPPING = OrderedDict( [ ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)), + ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)), ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py index 1f2d0227b8..7771a5a5bf 100644 --- a/src/diffusers/utils/logging.py +++ b/src/diffusers/utils/logging.py @@ -65,17 +65,14 @@ def _get_default_logging_level(): def _get_library_name() -> str: - return __name__.split(".")[0] def _get_library_root_logger() -> logging.Logger: - return logging.getLogger(_get_library_name()) def _configure_library_root_logger() -> None: - global _default_handler with _lock: @@ -93,7 +90,6 @@ def _configure_library_root_logger() -> None: def _reset_library_root_logger() -> None: - global _default_handler with _lock: diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index d8e695db59..b02f62d02d 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -59,40 +59,10 @@ class BaseOutput(OrderedDict): if not len(class_fields): raise ValueError(f"{self.__class__.__name__} has no fields.") - first_field = getattr(self, class_fields[0].name) - other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) - - if other_fields_are_none and not is_tensor(first_field): - if isinstance(first_field, dict): - iterator = first_field.items() - first_field_iterator = True - else: - try: - iterator = iter(first_field) - first_field_iterator = True - except TypeError: - first_field_iterator = False - - # if we provided an iterator as first field and the iterator is a (key, value) iterator - # set the associated fields - if first_field_iterator: - for element in iterator: - if ( - not isinstance(element, (list, tuple)) - or not len(element) == 2 - or not isinstance(element[0], str) - ): - break - setattr(self, element[0], element[1]) - if element[1] is not None: - self[element[0]] = element[1] - elif first_field is not None: - self[class_fields[0].name] = first_field - else: - for field in class_fields: - v = getattr(self, field.name) - if v is not None: - self[field.name] = v + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v def __delitem__(self, *args, **kwargs): raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index a94ecd58d5..4c9b17caa7 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -19,8 +19,10 @@ import unittest import numpy as np import torch +from diffusers.models.attention import AttentionBlock, SpatialTransformer from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.resnet import Downsample2D, Upsample2D +from diffusers.testing_utils import torch_device torch.backends.cuda.matmul.allow_tf32 = False @@ -216,3 +218,108 @@ class Downsample2DBlockTests(unittest.TestCase): output_slice = downsampled[0, -1, -3:, -3:] expected_slice = torch.tensor([-0.6586, 0.5985, 0.0721, 0.1256, -0.1492, 0.4436, -0.2544, 0.5021, 1.1522]) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + +class AttentionBlockTests(unittest.TestCase): + def test_attention_block_default(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + sample = torch.randn(1, 32, 64, 64).to(torch_device) + attentionBlock = AttentionBlock( + channels=32, + num_head_channels=1, + rescale_output_factor=1.0, + eps=1e-6, + num_groups=32, + ).to(torch_device) + with torch.no_grad(): + attention_scores = attentionBlock(sample) + + assert attention_scores.shape == (1, 32, 64, 64) + output_slice = attention_scores[0, -1, -3:, -3:] + + expected_slice = torch.tensor( + [-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427], device=torch_device + ) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + +class SpatialTransformerTests(unittest.TestCase): + def test_spatial_transformer_default(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + sample = torch.randn(1, 32, 64, 64).to(torch_device) + spatial_transformer_block = SpatialTransformer( + in_channels=32, + n_heads=1, + d_head=32, + dropout=0.0, + context_dim=None, + ).to(torch_device) + with torch.no_grad(): + attention_scores = spatial_transformer_block(sample) + + assert attention_scores.shape == (1, 32, 64, 64) + output_slice = attention_scores[0, -1, -3:, -3:] + + expected_slice = torch.tensor( + [-1.2447, -0.0137, -0.9559, -1.5223, 0.6991, -1.0126, -2.0974, 0.8921, -1.0201], device=torch_device + ) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def test_spatial_transformer_context_dim(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + sample = torch.randn(1, 64, 64, 64).to(torch_device) + spatial_transformer_block = SpatialTransformer( + in_channels=64, + n_heads=2, + d_head=32, + dropout=0.0, + context_dim=64, + ).to(torch_device) + with torch.no_grad(): + context = torch.randn(1, 4, 64).to(torch_device) + attention_scores = spatial_transformer_block(sample, context) + + assert attention_scores.shape == (1, 64, 64, 64) + output_slice = attention_scores[0, -1, -3:, -3:] + + expected_slice = torch.tensor( + [-0.2555, -0.8877, -2.4739, -2.2251, 1.2714, 0.0807, -0.4161, -1.6408, -0.0471], device=torch_device + ) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def test_spatial_transformer_dropout(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + sample = torch.randn(1, 32, 64, 64).to(torch_device) + spatial_transformer_block = ( + SpatialTransformer( + in_channels=32, + n_heads=2, + d_head=16, + dropout=0.3, + context_dim=None, + ) + .to(torch_device) + .eval() + ) + with torch.no_grad(): + attention_scores = spatial_transformer_block(sample) + + assert attention_scores.shape == (1, 32, 64, 64) + output_slice = attention_scores[0, -1, -3:, -3:] + + expected_slice = torch.tensor( + [-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device + ) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8c7c6312de..1e98fc9de7 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -15,11 +15,13 @@ import inspect import tempfile +import unittest from typing import Dict, List, Tuple import numpy as np import torch +from diffusers.modeling_utils import ModelMixin from diffusers.testing_utils import torch_device from diffusers.training_utils import EMAModel @@ -38,6 +40,11 @@ class ModelTesterMixin: new_model.to(torch_device) with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + _ = model(**self.dummy_input) + _ = new_model(**self.dummy_input) + image = model(**inputs_dict) if isinstance(image, dict): image = image.sample @@ -55,7 +62,12 @@ class ModelTesterMixin: model = self.model_class(**init_dict) model.to(torch_device) model.eval() + with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + model(**self.dummy_input) + first = model(**inputs_dict) if isinstance(first, dict): first = first.sample @@ -87,6 +99,26 @@ class ModelTesterMixin: expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 32) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_forward_signature(self): init_dict, _ = self.prepare_init_args_and_inputs_for_common() @@ -113,7 +145,7 @@ class ModelTesterMixin: new_model.to(torch_device) new_model.eval() - # check if all paramters shape are the same + # check if all parameters shape are the same for param_name in model.state_dict().keys(): param_1 = model.state_dict()[param_name] param_2 = new_model.state_dict()[param_name] @@ -132,6 +164,7 @@ class ModelTesterMixin: self.assertEqual(output_1.shape, output_2.shape) + @unittest.skipIf(torch_device == "mps", "Training is not supported in mps") def test_training(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -147,6 +180,7 @@ class ModelTesterMixin: loss = torch.nn.functional.mse_loss(output, noise) loss.backward() + @unittest.skipIf(torch_device == "mps", "Training is not supported in mps") def test_ema_training(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -165,10 +199,15 @@ class ModelTesterMixin: loss.backward() ema_model.step(model) - def test_scheduler_outputs_equivalence(self): + def test_outputs_equivalence(self): def set_nan_tensor_to_zero(t): + # Temporary fallback until `aten::_index_put_impl_` is implemented in mps + # Track progress in https://github.com/pytorch/pytorch/issues/77764 + device = t.device + if device.type == "mps": + t = t.to("cpu") t[t != t] = 0 - return t + return t.to(device) def recursive_check(tuple_object, dict_object): if isinstance(tuple_object, (List, Tuple)): @@ -198,7 +237,12 @@ class ModelTesterMixin: model.to(torch_device) model.eval() - outputs_dict = model(**inputs_dict) - outputs_tuple = model(**inputs_dict, return_dict=False) + with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + model(**self.dummy_input) + + outputs_dict = model(**inputs_dict) + outputs_tuple = model(**inputs_dict, return_dict=False) recursive_check(outputs_tuple, outputs_dict) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index c574a0092e..b16a4e1c44 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -138,11 +138,13 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): model.eval() model.to(torch_device) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + noise = torch.randn( + 1, + model.config.in_channels, + model.config.sample_size, + model.config.sample_size, + generator=torch.manual_seed(0), + ) noise = noise.to(torch_device) time_step = torch.tensor([10] * noise.shape[0]).to(torch_device) @@ -154,7 +156,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800]) # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) # TODO(Patrick) - Re-add this test after having cleaned up LDM @@ -191,7 +193,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): num_channels = 3 noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [10]).to(torch_device) + time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device) return {"sample": noise, "timestep": time_step} @@ -291,3 +293,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + + def test_forward_with_norm_groups(self): + # not required for this model + pass diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py index adf9767d2d..361eb618ab 100644 --- a/tests/test_models_vae.py +++ b/tests/test_models_vae.py @@ -18,6 +18,7 @@ import unittest import torch from diffusers import AutoencoderKL +from diffusers.modeling_utils import ModelMixin from diffusers.testing_utils import floats_tensor, torch_device from .test_modeling_common import ModelTesterMixin @@ -80,17 +81,38 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): model = model.to(torch_device) model.eval() - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) + # One-time warmup pass (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + image = image.to(torch_device) + with torch.no_grad(): + _ = model(image, sample_posterior=True).sample + generator = torch.manual_seed(0) + else: + generator = torch.Generator(device=torch_device).manual_seed(0) - image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + image = torch.randn( + 1, + model.config.in_channels, + model.config.sample_size, + model.config.sample_size, + generator=torch.manual_seed(0), + ) image = image.to(torch_device) with torch.no_grad(): - output = model(image, sample_posterior=True).sample + output = model(image, sample_posterior=True, generator=generator).sample output_slice = output[0, -1, -3:, -3:].flatten().cpu() - # fmt: off - expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03]) - # fmt: on + + # Since the VAE Gaussian prior's generator is seeded on the appropriate device, + # the expected output slices are not the same for CPU and GPU. + if torch_device in ("mps", "cpu"): + expected_output_slice = torch.tensor( + [-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026] + ) + else: + expected_output_slice = torch.tensor( + [-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485] + ) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) diff --git a/tests/test_models_vq.py b/tests/test_models_vq.py index c0acceccb4..7cce0ed13e 100644 --- a/tests/test_models_vq.py +++ b/tests/test_models_vq.py @@ -85,10 +85,13 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = image.to(torch_device) with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = model(image) output = model(image).sample output_slice = output[0, -1, -3:, -3:].flatten().cpu() # fmt: off expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143]) # fmt: on - self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index a05d57a73d..102a55a93e 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -14,6 +14,7 @@ # limitations under the License. import gc +import os import random import tempfile import unittest @@ -22,7 +23,6 @@ import numpy as np import torch import PIL -from datasets import load_dataset from diffusers import ( AutoencoderKL, DDIMPipeline, @@ -40,13 +40,17 @@ from diffusers import ( ScoreSdeVeScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, + StableDiffusionOnnxPipeline, StableDiffusionPipeline, UNet2DConditionModel, UNet2DModel, VQModel, ) +from diffusers.modeling_utils import WEIGHTS_NAME from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.testing_utils import floats_tensor, slow, torch_device +from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device +from diffusers.utils import CONFIG_NAME from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -167,7 +171,7 @@ class PipelineFastTests(unittest.TestCase): @property def dummy_safety_checker(self): def check(images, *args, **kwargs): - return images, False + return images, [False] * len(images) return check @@ -194,6 +198,10 @@ class PipelineFastTests(unittest.TestCase): ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None) + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = ddpm(num_inference_steps=1) + generator = torch.manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images @@ -207,8 +215,9 @@ class PipelineFastTests(unittest.TestCase): expected_slice = np.array( [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] ) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + tolerance = 1e-2 if torch_device != "mps" else 3e-2 + assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance def test_pndm_cifar10(self): unet = self.dummy_uncond_unet @@ -244,6 +253,14 @@ class PipelineFastTests(unittest.TestCase): ldm.set_progress_bar_config(disable=None) prompt = "A painting of a squirrel eating a burger" + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + generator = torch.manual_seed(0) + _ = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy")[ + "sample" + ] + generator = torch.manual_seed(0) image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[ "sample" @@ -316,6 +333,7 @@ class PipelineFastTests(unittest.TestCase): assert image.shape == (1, 128, 128, 3) expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -449,17 +467,18 @@ class PipelineFastTests(unittest.TestCase): sde_ve.to(torch_device) sde_ve.set_progress_bar_config(disable=None) - torch.manual_seed(0) - image = sde_ve(num_inference_steps=2, output_type="numpy").images + generator = torch.manual_seed(0) + image = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator).images - torch.manual_seed(0) - image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", return_dict=False)[0] + generator = torch.manual_seed(0) + image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", generator=generator, return_dict=False)[ + 0 + ] image_slice = image[0, -3:, -3:, -1] image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -473,6 +492,11 @@ class PipelineFastTests(unittest.TestCase): ldm.to(torch_device) ldm.set_progress_bar_config(disable=None) + # Warmup pass when using mps (see #372) + if torch_device == "mps": + generator = torch.manual_seed(0) + _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images + generator = torch.manual_seed(0) image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images @@ -628,7 +652,7 @@ class PipelineFastTests(unittest.TestCase): bert = self.dummy_text_encoder tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - image = self.dummy_image.to(device).permute(0, 2, 3, 1)[0] + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] init_image = Image.fromarray(np.uint8(image)).convert("RGB") mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128)) @@ -687,6 +711,34 @@ class PipelineTesterMixin(unittest.TestCase): gc.collect() torch.cuda.empty_cache() + def test_smart_download(self): + model_id = "hf-internal-testing/unet-pipeline-dummy" + with tempfile.TemporaryDirectory() as tmpdirname: + _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True) + local_repo_name = "--".join(["models"] + model_id.split("/")) + snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots") + snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0]) + + # inspect all downloaded files to make sure that everything is included + assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name)) + assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME)) + assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME)) + assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME)) + assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME)) + assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) + assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) + # let's make sure the super large numpy file: + # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy + # is not downloaded, but all the expected ones + assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) + + @property + def dummy_safety_checker(self): + def check(images, *args, **kwargs): + return images, [False] * len(images) + + return check + def test_from_pretrained_save_pretrained(self): # 1. Load models model = UNet2DModel( @@ -710,8 +762,8 @@ class PipelineTesterMixin(unittest.TestCase): new_ddpm.to(torch_device) generator = torch.manual_seed(0) - image = ddpm(generator=generator, output_type="numpy").images + generator = generator.manual_seed(0) new_image = new_ddpm(generator=generator, output_type="numpy").images @@ -731,8 +783,8 @@ class PipelineTesterMixin(unittest.TestCase): ddpm_from_hub.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image = ddpm(generator=generator, output_type="numpy").images + generator = generator.manual_seed(0) new_image = ddpm_from_hub(generator=generator, output_type="numpy").images @@ -755,8 +807,8 @@ class PipelineTesterMixin(unittest.TestCase): ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images + generator = generator.manual_seed(0) new_image = ddpm_from_hub(generator=generator, output_type="numpy").images @@ -949,7 +1001,7 @@ class PipelineTesterMixin(unittest.TestCase): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.9326, 0.923, 0.951, 0.9365, 0.9214, 0.951, 0.9365, 0.9414, 0.918]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow def test_score_sde_ve_pipeline(self): @@ -962,14 +1014,14 @@ class PipelineTesterMixin(unittest.TestCase): sde_ve.to(torch_device) sde_ve.set_progress_bar_config(disable=None) - torch.manual_seed(0) - image = sde_ve(num_inference_steps=300, output_type="numpy").images + generator = torch.manual_seed(0) + image = sde_ve(num_inference_steps=10, output_type="numpy", generator=generator).images image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633]) + expected_slice = np.array([0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow @@ -1118,42 +1170,86 @@ class PipelineTesterMixin(unittest.TestCase): @slow @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") - def test_stable_diffusion_img2img_pipeline(self): - ds = load_dataset("hf-internal-testing/diffusers-images", split="train") - - init_image = ds[2]["image"].resize((768, 512)) - output_image = ds[0]["image"].resize((768, 512)) + def test_stable_diffusion_text2img_pipeline(self): + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/text2img/astronaut_riding_a_horse.png" + ) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 model_id = "CompVis/stable-diffusion-v1-4" - pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + pipe = StableDiffusionPipeline.from_pretrained( model_id, - revision="fp16", # fp16 to infer 768x512 images with 16GB of VRAM - torch_dtype=torch.float16, + safety_checker=self.dummy_safety_checker, use_auth_token=True, ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "astronaut riding a horse" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 + + @slow + @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + def test_stable_diffusion_img2img_pipeline(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/fantasy_landscape.png" + ) + init_image = init_image.resize((768, 512)) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 + + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + model_id, + safety_checker=self.dummy_safety_checker, + use_auth_token=True, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() prompt = "A fantasy landscape, trending on artstation" generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast("cuda"): - output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator) + output = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) image = output.images[0] - expected_array = np.array(output_image) / 255.0 - sampled_array = np.array(image) / 255.0 - - assert sampled_array.shape == (512, 768, 3) - assert np.max(np.abs(sampled_array - expected_array)) < 1e-4 + assert image.shape == (512, 768, 3) + # img2img is flaky across GPUs even in fp32, so using MAE here + assert np.abs(expected_image - image).mean() < 1e-2 @slow @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") def test_stable_diffusion_img2img_pipeline_k_lms(self): - ds = load_dataset("hf-internal-testing/diffusers-images", split="train") - - init_image = ds[2]["image"].resize((768, 512)) - output_image = ds[1]["image"].resize((768, 512)) + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" + ) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/fantasy_landscape_k_lms.png" + ) + init_image = init_image.resize((768, 512)) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") @@ -1161,61 +1257,90 @@ class PipelineTesterMixin(unittest.TestCase): pipe = StableDiffusionImg2ImgPipeline.from_pretrained( model_id, scheduler=lms, - revision="fp16", # fp16 to infer 768x512 images with 16GB of VRAM - torch_dtype=torch.float16, + safety_checker=self.dummy_safety_checker, use_auth_token=True, ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() prompt = "A fantasy landscape, trending on artstation" generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast("cuda"): - output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator) + output = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) image = output.images[0] - expected_array = np.array(output_image) / 255.0 - sampled_array = np.array(image) / 255.0 - - assert sampled_array.shape == (512, 768, 3) - assert np.max(np.abs(sampled_array - expected_array)) < 1e-4 + assert image.shape == (512, 768, 3) + # img2img is flaky across GPUs even in fp32, so using MAE here + assert np.abs(expected_image - image).mean() < 1e-2 @slow @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") def test_stable_diffusion_inpaint_pipeline(self): - ds = load_dataset("hf-internal-testing/diffusers-images", split="train") - - init_image = ds[3]["image"].resize((768, 512)) - mask_image = ds[4]["image"].resize((768, 512)) - output_image = ds[5]["image"].resize((768, 512)) + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/red_cat_sitting_on_a_park_bench.png" + ) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 model_id = "CompVis/stable-diffusion-v1-4" pipe = StableDiffusionInpaintPipeline.from_pretrained( model_id, - revision="fp16", # fp16 to infer 768x512 images in 16GB of VRAM - torch_dtype=torch.float16, + safety_checker=self.dummy_safety_checker, use_auth_token=True, ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() - prompt = "A red cat sitting on a parking bench" + prompt = "A red cat sitting on a park bench" generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast("cuda"): - output = pipe( - prompt=prompt, - init_image=init_image, - mask_image=mask_image, - strength=0.75, - guidance_scale=7.5, - generator=generator, - ) + output = pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) image = output.images[0] - expected_array = np.array(output_image) / 255.0 - sampled_array = np.array(image) / 255.0 + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 - assert sampled_array.shape == (512, 768, 3) - assert np.max(np.abs(sampled_array - expected_array)) < 1e-3 + @slow + def test_stable_diffusion_onnx(self): + from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models + + with tempfile.TemporaryDirectory() as tmpdirname: + convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14) + + sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider") + + prompt = "A painting of a squirrel eating a burger" + np.random.seed(0) + output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=20, output_type="np") + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0385, 0.0252, 0.0234, 0.0287, 0.0358, 0.0287, 0.0276, 0.0235, 0.0010]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 3c2e786fc1..7377797beb 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -19,7 +19,7 @@ from typing import Dict, List, Tuple import numpy as np import torch -from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler +from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler torch.backends.cuda.matmul.allow_tf32 = False @@ -318,13 +318,14 @@ class DDPMSchedulerTest(SchedulerCommonTest): model = self.dummy_model() sample = self.dummy_sample_deter + generator = torch.manual_seed(0) for t in reversed(range(num_trained_timesteps)): # 1. predict noise residual residual = model(sample, t) # 2. predict previous mean of sample x_t-1 - pred_prev_sample = scheduler.step(residual, t, sample).prev_sample + pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample # if t > 0: # noise = self.dummy_sample_deter @@ -336,7 +337,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 259.0883) < 1e-2 + assert abs(result_sum.item() - 258.9070) < 1e-2 assert abs(result_mean.item() - 0.3374) < 1e-3 @@ -356,10 +357,38 @@ class DDIMSchedulerTest(SchedulerCommonTest): config.update(**kwargs) return config + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps, eta = 10, 0.0 + + model = self.dummy_model() + sample = self.dummy_sample_deter + + scheduler.set_timesteps(num_inference_steps) + + for t in scheduler.timesteps: + residual = model(sample, t) + sample = scheduler.step(residual, t, sample, eta).prev_sample + + return sample + def test_timesteps(self): for timesteps in [100, 500, 1000]: self.check_over_configs(num_train_timesteps=timesteps) + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(5) + assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1])) + def test_betas(self): for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): self.check_over_configs(beta_start=beta_start, beta_end=beta_end) @@ -378,7 +407,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): def test_inference_steps(self): for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): - self.check_over_forward(num_inference_steps=num_inference_steps) + self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) def test_eta(self): for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]): @@ -397,20 +426,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): assert torch.sum(torch.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5 def test_full_loop_no_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - - num_inference_steps, eta = 10, 0.0 - - model = self.dummy_model() - sample = self.dummy_sample_deter - - scheduler.set_timesteps(num_inference_steps) - for t in scheduler.timesteps: - residual = model(sample, t) - - sample = scheduler.step(residual, t, sample, eta).prev_sample + sample = self.full_loop() result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) @@ -418,6 +434,24 @@ class DDIMSchedulerTest(SchedulerCommonTest): assert abs(result_sum.item() - 172.0067) < 1e-2 assert abs(result_mean.item() - 0.223967) < 1e-3 + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 149.8295) < 1e-2 + assert abs(result_mean.item() - 0.1951) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 149.0784) < 1e-2 + assert abs(result_mean.item() - 0.1941) < 1e-3 + class PNDMSchedulerTest(SchedulerCommonTest): scheduler_classes = (PNDMScheduler,) @@ -502,6 +536,26 @@ class PNDMSchedulerTest(SchedulerCommonTest): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.prk_timesteps): + residual = model(sample, t) + sample = scheduler.step_prk(residual, t, sample).prev_sample + + for i, t in enumerate(scheduler.plms_timesteps): + residual = model(sample, t) + sample = scheduler.step_plms(residual, t, sample).prev_sample + + return sample + def test_pytorch_equal_numpy(self): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) @@ -605,8 +659,23 @@ class PNDMSchedulerTest(SchedulerCommonTest): for timesteps in [100, 1000]: self.check_over_configs(num_train_timesteps=timesteps) + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(10) + assert torch.equal( + scheduler.timesteps, + torch.tensor( + [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1] + ), + ) + def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]): + for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): self.check_over_configs(beta_start=beta_start, beta_end=beta_end) def test_schedules(self): @@ -619,7 +688,24 @@ class PNDMSchedulerTest(SchedulerCommonTest): def test_inference_steps(self): for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): - self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + self.check_over_forward(num_inference_steps=num_inference_steps) + + def test_pow_of_3_inference_steps(self): + # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 + num_inference_steps = 27 + + for scheduler_class in self.scheduler_classes: + sample = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(num_inference_steps) + + # before power of 3 fix, would error on first step, so we only need to do two + for i, t in enumerate(scheduler.prk_timesteps[:2]): + sample = scheduler.step_prk(residual, t, sample).prev_sample def test_inference_plms_no_past_residuals(self): with self.assertRaises(ValueError): @@ -630,34 +716,36 @@ class PNDMSchedulerTest(SchedulerCommonTest): scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample def test_full_loop_no_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter - scheduler.set_timesteps(num_inference_steps) - - for i, t in enumerate(scheduler.prk_timesteps): - residual = model(sample, t) - sample = scheduler.step_prk(residual, i, sample).prev_sample - - for i, t in enumerate(scheduler.plms_timesteps): - residual = model(sample, t) - sample = scheduler.step_plms(residual, i, sample).prev_sample - + sample = self.full_loop() result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 428.8788) < 1e-2 - assert abs(result_mean.item() - 0.5584) < 1e-3 + assert abs(result_sum.item() - 198.1318) < 1e-2 + assert abs(result_mean.item() - 0.2580) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 230.0399) < 1e-2 + assert abs(result_mean.item() - 0.2995) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 186.9482) < 1e-2 + assert abs(result_mean.item() - 0.2434) < 1e-3 class ScoreSdeVeSchedulerTest(unittest.TestCase): # TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration) scheduler_classes = (ScoreSdeVeScheduler,) - forward_default_kwargs = (("seed", 0),) + forward_default_kwargs = () @property def dummy_sample(self): @@ -718,13 +806,19 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) - output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample - new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample + output = scheduler.step_pred( + residual, time_step, sample, generator=torch.manual_seed(0), **kwargs + ).prev_sample + new_output = new_scheduler.step_pred( + residual, time_step, sample, generator=torch.manual_seed(0), **kwargs + ).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - output = scheduler.step_correct(residual, sample, **kwargs).prev_sample - new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample + output = scheduler.step_correct(residual, sample, generator=torch.manual_seed(0), **kwargs).prev_sample + new_output = new_scheduler.step_correct( + residual, sample, generator=torch.manual_seed(0), **kwargs + ).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical" @@ -743,13 +837,19 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) - output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample - new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample + output = scheduler.step_pred( + residual, time_step, sample, generator=torch.manual_seed(0), **kwargs + ).prev_sample + new_output = new_scheduler.step_pred( + residual, time_step, sample, generator=torch.manual_seed(0), **kwargs + ).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - output = scheduler.step_correct(residual, sample, **kwargs).prev_sample - new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample + output = scheduler.step_correct(residual, sample, generator=torch.manual_seed(0), **kwargs).prev_sample + new_output = new_scheduler.step_correct( + residual, sample, generator=torch.manual_seed(0), **kwargs + ).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical" @@ -779,26 +879,27 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): scheduler.set_sigmas(num_inference_steps) scheduler.set_timesteps(num_inference_steps) + generator = torch.manual_seed(0) for i, t in enumerate(scheduler.timesteps): sigma_t = scheduler.sigmas[i] - for _ in range(scheduler.correct_steps): + for _ in range(scheduler.config.correct_steps): with torch.no_grad(): model_output = model(sample, sigma_t) - sample = scheduler.step_correct(model_output, sample, **kwargs).prev_sample + sample = scheduler.step_correct(model_output, sample, generator=generator, **kwargs).prev_sample with torch.no_grad(): model_output = model(sample, sigma_t) - output = scheduler.step_pred(model_output, t, sample, **kwargs) + output = scheduler.step_pred(model_output, t, sample, generator=generator, **kwargs) sample, _ = output.prev_sample, output.prev_sample_mean result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 14379591680.0) < 1e-2 - assert abs(result_mean.item() - 18723426.0) < 1e-3 + assert np.isclose(result_sum.item(), 14372758528.0) + assert np.isclose(result_mean.item(), 18714530.0) def test_step_shape(self): kwargs = dict(self.forward_default_kwargs) @@ -817,8 +918,88 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - output_0 = scheduler.step_pred(residual, 0, sample, **kwargs).prev_sample - output_1 = scheduler.step_pred(residual, 1, sample, **kwargs).prev_sample + output_0 = scheduler.step_pred(residual, 0, sample, generator=torch.manual_seed(0), **kwargs).prev_sample + output_1 = scheduler.step_pred(residual, 1, sample, generator=torch.manual_seed(0), **kwargs).prev_sample self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) + + +class LMSDiscreteSchedulerTest(SchedulerCommonTest): + scheduler_classes = (LMSDiscreteScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "trained_betas": None, + "tensor_format": "pt", + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [0, 500, 800]: + self.check_over_forward(time_step=t) + + def test_pytorch_equal_numpy(self): + for scheduler_class in self.scheduler_classes: + sample_pt = self.dummy_sample + residual_pt = 0.1 * sample_pt + + sample = sample_pt.numpy() + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler_config["tensor_format"] = "np" + scheduler = scheduler_class(**scheduler_config) + + scheduler_config["tensor_format"] = "pt" + scheduler_pt = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + scheduler_pt.set_timesteps(self.num_inference_steps) + + output = scheduler.step(residual, 1, sample).prev_sample + output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample + assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.sigmas[0] + + for i, t in enumerate(scheduler.timesteps): + sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5) + + model_output = model(sample, t) + + output = scheduler.step(model_output, i, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 1006.388) < 1e-2 + assert abs(result_mean.item() - 1.31) < 1e-3 diff --git a/tests/test_training.py b/tests/test_training.py index 27caf03365..519c5ab9e7 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -52,7 +52,7 @@ class TrainingTests(unittest.TestCase): tensor_format="pt", ) - assert ddpm_scheduler.num_train_timesteps == ddim_scheduler.num_train_timesteps + assert ddpm_scheduler.config.num_train_timesteps == ddim_scheduler.config.num_train_timesteps # shared batches for DDPM and DDIM set_seed(0) diff --git a/utils/check_inits.py b/utils/check_inits.py index 98d4caf010..c5e25182a4 100644 --- a/utils/check_inits.py +++ b/utils/check_inits.py @@ -288,7 +288,7 @@ def check_submodules(): if len(module_not_registered) > 0: list_of_modules = "\n".join(f"- {module}" for module in module_not_registered) raise ValueError( - "The following submodules are not properly registed in the main init of Transformers:\n" + "The following submodules are not properly registered in the main init of Transformers:\n" f"{list_of_modules}\n" "Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value." ) diff --git a/utils/check_table.py b/utils/check_table.py index 6c74308c2e..28c6ea8891 100644 --- a/utils/check_table.py +++ b/utils/check_table.py @@ -53,7 +53,7 @@ def _find_text_in_file(filename, start_prompt, end_prompt): return "".join(lines[start_index:end_index]), start_index, end_index, lines -# Add here suffixes that are used to identify models, seperated by | +# Add here suffixes that are used to identify models, separated by | ALLOWED_MODEL_SUFFIXES = "Model|Encoder|Decoder|ForConditionalGeneration" # Regexes that match TF/Flax/PT model names. _re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)") @@ -88,11 +88,11 @@ def _center_text(text, width): def get_model_table_from_auto_modules(): """Generates an up-to-date model table from the content of the auto modules.""" # Dictionary model names to config. - config_maping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES + config_mapping_names = diffusers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES model_name_to_config = { - name: config_maping_names[code] + name: config_mapping_names[code] for code, name in diffusers_module.MODEL_NAMES_MAPPING.items() - if code in config_maping_names + if code in config_mapping_names } model_name_to_prefix = {name: config.replace("ConfigMixin", "") for name, config in model_name_to_config.items()} diff --git a/utils/check_tf_ops.py b/utils/check_tf_ops.py index f6c2b8bae4..a3b9593bb2 100644 --- a/utils/check_tf_ops.py +++ b/utils/check_tf_ops.py @@ -41,7 +41,7 @@ INTERNAL_OPS = [ ] -def onnx_compliancy(saved_model_path, strict, opset): +def onnx_compliance(saved_model_path, strict, opset): saved_model = SavedModel() onnx_ops = [] @@ -98,4 +98,4 @@ if __name__ == "__main__": args = parser.parse_args() if args.framework == "onnx": - onnx_compliancy(args.saved_model_path, args.strict, args.opset) + onnx_compliance(args.saved_model_path, args.strict, args.opset) diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py index 6501654872..e1e079a99c 100644 --- a/utils/custom_init_isort.py +++ b/utils/custom_init_isort.py @@ -178,7 +178,7 @@ def sort_imports(file, check_only=True): code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:" ) - # We ignore block 0 (everything untils start_prompt) and the last block (everything after end_prompt). + # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt). for block_idx in range(1, len(main_blocks) - 1): # Check if the block contains some `_import_structure`s thingy to sort. block = main_blocks[block_idx] @@ -202,7 +202,7 @@ def sort_imports(file, check_only=True): internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent) # We have two categories of import key: list or _import_structu[key].append/extend pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key - # Grab the keys, but there is a trap: some lines are empty or jsut comments. + # Grab the keys, but there is a trap: some lines are empty or just comments. keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks] # We only sort the lines with a key. keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None] diff --git a/utils/get_modified_files.py b/utils/get_modified_files.py new file mode 100644 index 0000000000..44c60e96ab --- /dev/null +++ b/utils/get_modified_files.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# this script reports modified .py files under the desired list of top-level sub-dirs passed as a list of arguments, e.g.: +# python ./utils/get_modified_files.py utils src tests examples +# +# it uses git to find the forking point and which files were modified - i.e. files not under git won't be considered +# since the output of this script is fed into Makefile commands it doesn't print a newline after the results + +import re +import subprocess +import sys + + +fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8") +modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split() + +joined_dirs = "|".join(sys.argv[1:]) +regex = re.compile(rf"^({joined_dirs}).*?\.py$") + +relevant_modified_files = [x for x in modified_files if regex.match(x)] +print(" ".join(relevant_modified_files), end="") diff --git a/utils/stale.py b/utils/stale.py new file mode 100644 index 0000000000..4162875a62 --- /dev/null +++ b/utils/stale.py @@ -0,0 +1,61 @@ +# Copyright 2022 The HuggingFace Team, the AllenNLP library authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Script to close stale issue. Taken in part from the AllenNLP repository. +https://github.com/allenai/allennlp. +""" +import os +from datetime import datetime as dt + +from github import Github + + +LABELS_TO_EXEMPT = [ + "good first issue", + "good second issue", + "good difficult issue", + "enhancement", + "new pipeline/model", + "new scheduler", + "wip", +] + + +def main(): + g = Github(os.environ["GITHUB_TOKEN"]) + repo = g.get_repo("huggingface/diffusers") + open_issues = repo.get_issues(state="open") + + for issue in open_issues: + comments = sorted([comment for comment in issue.get_comments()], key=lambda i: i.created_at, reverse=True) + last_comment = comments[0] if len(comments) > 0 else None + if ( + last_comment is not None + and last_comment.user.login != "github-actions[bot]" + and (dt.utcnow() - issue.updated_at).days > 23 + and (dt.utcnow() - issue.created_at).days >= 30 + and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels()) + ): + issue.create_comment( + "This issue has been automatically marked as stale because it has not had " + "recent activity. If you think this still needs to be addressed " + "please comment on this thread.\n\nPlease note that issues that do not follow the " + "[contributing guidelines](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md) " + "are likely to be ignored." + ) + issue.edit(labels=["stale"]) + + +if __name__ == "__main__": + main()