diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 1208a0d25f..f63d4ffda4 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -41,4 +41,15 @@ jobs: - name: Run all non-slow selected tests on CPU run: | - python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s tests/ + python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/ + + - name: Failure short reports + if: ${{ failure() }} + run: cat reports/tests_torch_cpu_failures_short.txt + + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v2 + with: + name: pr_torch_test_reports + path: reports diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 58ecc0d3c2..3db6814e07 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -49,4 +49,66 @@ jobs: env: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} run: | - python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s tests/ + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_gpu tests/ + + - name: Failure short reports + if: ${{ failure() }} + run: cat reports/tests_torch_gpu_failures_short.txt + + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v2 + with: + name: torch_test_reports + path: reports + + + + run_examples_single_gpu: + name: Examples tests + strategy: + fail-fast: false + matrix: + machine_type: [ single-gpu ] + runs-on: [ self-hosted, docker-gpu, '${{ matrix.machine_type }}' ] + container: + image: nvcr.io/nvidia/pytorch:22.07-py3 + options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ + + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - name: NVIDIA-SMI + run: | + nvidia-smi + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip uninstall -y torch torchvision torchtext + python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 + python -m pip install -e .[quality,test,training] + + - name: Environment + run: | + python utils/print_env.py + + - name: Run example tests on GPU + env: + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + run: | + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_gpu examples/ + + - name: Failure short reports + if: ${{ failure() }} + run: cat reports/examples_torch_gpu_failures_short.txt + + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v2 + with: + name: examples_test_reports + path: reports diff --git a/_typos.toml b/_typos.toml index 4025388915..551099f981 100644 --- a/_typos.toml +++ b/_typos.toml @@ -4,8 +4,9 @@ [default.extend-identifiers] [default.extend-words] -NIN_="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py +NIN="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py nd="np" # nd may be np (numpy) +parms="parms" # parms is used in scripts/convert_original_stable_diffusion_to_diffusers.py [files] diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index 525548e7c3..c92fdccb83 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## AutoencoderKL [[autodoc]] 
AutoencoderKL + +## FlaxModelMixin +[[autodoc]] FlaxModelMixin + +## FlaxUNet2DConditionOutput +[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput + +## FlaxUNet2DConditionModel +[[autodoc]] FlaxUNet2DConditionModel + +## FlaxDecoderOutput +[[autodoc]] models.vae_flax.FlaxDecoderOutput + +## FlaxAutoencoderKLOutput +[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput + +## FlaxAutoencoderKL +[[autodoc]] FlaxAutoencoderKL diff --git a/docs/source/api/pipelines/overview.mdx b/docs/source/api/pipelines/overview.mdx index 881e5cdbd9..a9b1bb2821 100644 --- a/docs/source/api/pipelines/overview.mdx +++ b/docs/source/api/pipelines/overview.mdx @@ -53,7 +53,7 @@ available a colab notebook to directly try them out. | [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) -| [stochatic_karras_ve](./stochatic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. @@ -85,7 +85,7 @@ not be used for training. If you want to store the gradients during the forward We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**. -- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file iteslf, should be inherited from (and only from) the [`DiffusionPipeline` class](.../diffusion_pipeline) or be directly attached to the model and scheduler components of the pipeline. +- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](.../diffusion_pipeline) or be directly attached to the model and scheduler components of the pipeline. 
- **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method. diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 453b4a5b78..434c58cc8b 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -44,6 +44,6 @@ available a colab notebook to directly try them out. | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) -| [stochatic_karras_ve](./api/pipelines/stochatic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | **Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers. diff --git a/docs/source/using-diffusers/configuration.mdx b/docs/source/using-diffusers/configuration.mdx index 044f3937b9..a64cce0606 100644 --- a/docs/source/using-diffusers/configuration.mdx +++ b/docs/source/using-diffusers/configuration.mdx @@ -27,6 +27,6 @@ pip install diffusers ### Schedulers -### Pipeliens +### Pipelines diff --git a/docs/source/using-diffusers/unconditional_image_generation.mdx b/docs/source/using-diffusers/unconditional_image_generation.mdx index 8f5449f8fb..ba119defb8 100644 --- a/docs/source/using-diffusers/unconditional_image_generation.mdx +++ b/docs/source/using-diffusers/unconditional_image_generation.mdx @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. -# Unonditional Image Generation +# Unconditional Image Generation The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference diff --git a/examples/community/README.md b/examples/community/README.md index ce6f57b018..aada72163b 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -2,5 +2,6 @@ **Community** examples consist of both inference and training examples that have been added by the community. 
-| Example | Description | Author | | -|:----------|:-------------|:-------------|------:| +| Example | Description | Author | Colab | +|:----------|:----------------------|:-----------------|----------:| +| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion| [Suraj Patil](https://github.com/patil-suraj/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py new file mode 100644 index 0000000000..f781757359 --- /dev/null +++ b/examples/community/clip_guided_stable_diffusion.py @@ -0,0 +1,321 @@ +import inspect +from typing import List, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput +from torchvision import transforms +from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer + + +class MakeCutouts(nn.Module): + def __init__(self, cut_size, cut_power=1.0): + super().__init__() + + self.cut_size = cut_size + self.cut_power = cut_power + + def forward(self, pixel_values, num_cutouts): + sideY, sideX = pixel_values.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + cutouts = [] + for _ in range(num_cutouts): + size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size] + cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size)) + return torch.cat(cutouts) + + +def spherical_dist_loss(x, y): + x = F.normalize(x, dim=-1) + y = F.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + + +def set_requires_grad(model, value): + for param in model.parameters(): + param.requires_grad = value + + +class CLIPGuidedStableDiffusion(DiffusionPipeline): + """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000 + - https://github.com/Jack000/glid-3-xl + - https://github.dev/crowsonkb/k-diffusion + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + clip_model: CLIPModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler], + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + self.register_modules( + vae=vae, + text_encoder=text_encoder, + clip_model=clip_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + + self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) + self.make_cutouts = MakeCutouts(feature_extractor.size) + + set_requires_grad(self.text_encoder, False) + set_requires_grad(self.clip_model, False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = 
self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + self.enable_attention_slicing(None) + + def freeze_vae(self): + set_requires_grad(self.vae, False) + + def unfreeze_vae(self): + set_requires_grad(self.vae, True) + + def freeze_unet(self): + set_requires_grad(self.unet, False) + + def unfreeze_unet(self): + set_requires_grad(self.unet, True) + + @torch.enable_grad() + def cond_fn( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + text_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts=True, + ): + latents = latents.detach().requires_grad_() + + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latents / ((sigma**2 + 1) ** 0.5) + else: + latent_model_input = latents + + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + + if isinstance(self.scheduler, PNDMScheduler): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) + + if use_cutouts: + image = self.make_cutouts(image, num_cutouts) + else: + image = transforms.Resize(self.feature_extractor.size)(image) + image = self.normalize(image) + + image_embeddings_clip = self.clip_model.get_image_features(image).float() + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + + if use_cutouts: + dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip) + dists = dists.view([num_cutouts, sample.shape[0], -1]) + loss = dists.sum(2).mean(0).sum() * clip_guidance_scale + else: + loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale + + grads = -torch.autograd.grad(loss, latents)[0] + + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + clip_guidance_scale: Optional[float] = 100, + clip_prompt: Optional[Union[str, List[str]]] = None, + num_cutouts: Optional[int] = 4, + use_cutouts: Optional[bool] = True, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: 
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + if clip_guidance_scale > 0: + if clip_prompt is not None: + clip_text_input = self.tokenizer( + clip_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).input_ids.to(self.device) + else: + clip_text_input = text_input.input_ids.to(self.device) + text_embeddings_clip = self.clip_model.get_text_features(clip_text_input) + text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = text_input.input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. 
+ latents_device = "cpu" if self.device.type == "mps" else self.device + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + if latents is None: + latents = torch.randn( + latents_shape, + generator=generator, + device=latents_device, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform classifier free guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # perform clip guidance + if clip_guidance_scale > 0: + text_embeddings_for_guidance = ( + text_embeddings.chunk(2)[0] if do_classifier_free_guidance else text_embeddings + ) + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + text_embeddings_clip, + clip_guidance_scale, + num_cutouts, + use_cutouts, + ) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, i, latents).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) diff --git a/examples/conftest.py b/examples/conftest.py new file mode 100644 index 0000000000..a72bc85310 --- /dev/null +++ b/examples/conftest.py @@ -0,0 +1,45 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# tests directory-specific settings - this file is run automatically +# by pytest before any tests are run + +import sys +import warnings +from os.path import abspath, dirname, join + + +# allow having multiple repository checkouts and not needing to remember to rerun +# 'pip install -e .[dev]' when switching between checkouts and running tests. +git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src")) +sys.path.insert(1, git_repo_path) + + +# silence FutureWarning warnings in tests since often we can't act on them until +# they become normal warnings - i.e. the tests still need to test the current functionality +warnings.simplefilter(action="ignore", category=FutureWarning) + + +def pytest_addoption(parser): + from diffusers.testing_utils import pytest_addoption_shared + + pytest_addoption_shared(parser) + + +def pytest_terminal_summary(terminalreporter): + from diffusers.testing_utils import pytest_terminal_summary_main + + make_reports = terminalreporter.config.getoption("--make-reports") + if make_reports: + pytest_terminal_summary_main(terminalreporter, id=make_reports) diff --git a/examples/test_examples.py b/examples/test_examples.py new file mode 100644 index 0000000000..0099d17e63 --- /dev/null +++ b/examples/test_examples.py @@ -0,0 +1,124 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc.. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +import shutil +import subprocess +import sys +import tempfile +import unittest +from typing import List + +from accelerate.utils import write_basic_config +from diffusers.testing_utils import slow + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() + + +# These utils relate to ensuring the right error message is received when running scripts +class SubprocessCallException(Exception): + pass + + +def run_command(command: List[str], return_stdout=False): + """ + Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. 
Will also properly capture + if an error occured while running `command` + """ + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT) + if return_stdout: + if hasattr(output, "decode"): + output = output.decode("utf-8") + return output + except subprocess.CalledProcessError as e: + raise SubprocessCallException( + f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + ) from e + + +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class ExamplesTestsAccelerate(unittest.TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._tmpdir = tempfile.mkdtemp() + cls.configPath = os.path.join(cls._tmpdir, "default_config.yml") + + write_basic_config(save_location=cls.configPath) + cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath] + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + shutil.rmtree(cls._tmpdir) + + @slow + def test_train_unconditional(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/unconditional_image_generation/train_unconditional.py + --dataset_name huggan/few-shot-aurora + --resolution 64 + --output_dir {tmpdir} + --train_batch_size 4 + --num_epochs 1 + --gradient_accumulation_steps 1 + --learning_rate 1e-3 + --lr_warmup_steps 5 + --mixed_precision fp16 + """.split() + + run_command(self._launch_args + test_args, return_stdout=True) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + # logging test + self.assertTrue(len(os.listdir(os.path.join(tmpdir, "logs", "train_unconditional"))) > 0) + + @slow + def test_textual_inversion(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/textual_inversion/textual_inversion.py + --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 + --use_auth_token + --train_data_dir docs/source/imgs + --learnable_property object + --placeholder_token + --initializer_token toy + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 2 + --max_train_steps 10 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --mixed_precision fp16 + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.bin"))) diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py index 0e4550b788..beeacfe377 100644 --- a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py +++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py @@ -13,11 +13,14 @@ # limitations under the License. 
import argparse +import os +import shutil from pathlib import Path import torch from torch.onnx import export +import onnx from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline from diffusers.onnx_utils import OnnxRuntimeModel from packaging import version @@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int): ) # UNET + unet_path = output_path / "unet" / "model.onnx" onnx_export( pipeline.unet, model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False), - output_path=output_path / "unet" / "model.onnx", + output_path=unet_path, ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], output_names=["out_sample"], # has to be different from "sample" for correct tracing dynamic_axes={ @@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int): opset=opset, use_external_data_format=True, # UNet is > 2GB, so the weights need to be split ) + unet_model_path = str(unet_path.absolute().as_posix()) + unet_dir = os.path.dirname(unet_model_path) + unet = onnx.load(unet_model_path) + # clean up existing tensor files + shutil.rmtree(unet_dir) + os.mkdir(unet_dir) + # collate external tensor files into one + onnx.save_model( + unet, + unet_model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + convert_attribute=False, + ) # VAE ENCODER vae_encoder = pipeline.vae diff --git a/setup.py b/setup.py index a9022ea7b6..20c9ea61f5 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ from setuptools import find_packages, setup # 1. all dependencies should be listed here with their version requirements if any # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py _deps = [ - "Pillow", + "Pillow<10.0", # keep the PIL.Image.Resampling deprecation away "accelerate>=0.11.0", "black==22.8", "datasets", @@ -90,8 +90,9 @@ _deps = [ "isort>=5.5.4", "jax>=0.2.8,!=0.3.2,<=0.3.6", "jaxlib>=0.1.65,<=0.3.6", - "modelcards==0.1.4", + "modelcards>=0.1.4", "numpy", + "onnxruntime-gpu", "pytest", "pytest-timeout", "pytest-xdist", @@ -100,6 +101,7 @@ _deps = [ "requests", "tensorboard", "torch>=1.4", + "torchvision", "transformers>=4.21.0", ] @@ -171,10 +173,19 @@ extras = {} extras = {} -extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"] -extras["docs"] = ["hf-doc-builder"] -extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"] -extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"] +extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder") +extras["docs"] = deps_list("hf-doc-builder") +extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards") +extras["test"] = deps_list( + "datasets", + "onnxruntime-gpu", + "pytest", + "pytest-timeout", + "pytest-xdist", + "scipy", + "torchvision", + "transformers" +) extras["torch"] = deps_list("torch") if os.name == "nt": # windows diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ecf4fe5fef..acdddaac4d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -65,6 +65,8 @@ else: if is_flax_available(): from .modeling_flax_utils import FlaxModelMixin from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .models.vae_flax import FlaxAutoencoderKL + from .pipeline_flax_utils import FlaxDiffusionPipeline from .schedulers import ( 
FlaxDDIMScheduler, FlaxDDPMScheduler, @@ -75,3 +77,8 @@ if is_flax_available(): ) else: from .utils.dummy_flax_objects import * # noqa F403 + +if is_flax_available() and is_transformers_available(): + from .pipelines import FlaxStableDiffusionPipeline +else: + from .utils.dummy_flax_and_transformers_objects import * # noqa F403 diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index f5e5d36ffd..19f58fd816 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -154,15 +154,25 @@ class ConfigMixin: """ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) - init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + + # Return model and optionally state and/or unused_kwargs model = cls(**init_dict) + return_tuple = (model,) + + # Flax schedulers have a state, so return it. + if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False): + state = model.create_state() + return_tuple += (state,) if return_unused_kwargs: - return model, unused_kwargs + return return_tuple + (unused_kwargs,) else: - return model + return return_tuple if len(return_tuple) > 1 else model @classmethod def get_config_dict( @@ -272,7 +282,7 @@ class ConfigMixin: # remove general kwargs if present in dict if "kwargs" in expected_keys: expected_keys.remove("kwargs") - # remove flax interal keys + # remove flax internal keys if hasattr(cls, "_flax_internal_args"): for arg in cls._flax_internal_args: expected_keys.remove(arg) @@ -446,6 +456,9 @@ def flax_register_to_config(cls): # Make sure init_kwargs override default kwargs new_kwargs = {**default_kwargs, **init_kwargs} + # dtype should be part of `init_kwargs`, but not `new_kwargs` + if "dtype" in new_kwargs: + new_kwargs.pop("dtype") # Get positional arguments aligned with kwargs for i, arg in enumerate(args): diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index f6fb397303..09a7baad56 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -2,7 +2,7 @@ # 1. modify the `_deps` dict in setup.py # 2. run `make deps_table_update`` deps = { - "Pillow": "Pillow", + "Pillow": "Pillow<10.0", "accelerate": "accelerate>=0.11.0", "black": "black==22.8", "datasets": "datasets", @@ -15,8 +15,10 @@ deps = { "isort": "isort>=5.5.4", "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6", "jaxlib": "jaxlib>=0.1.65,<=0.3.6", - "modelcards": "modelcards==0.1.4", + "modelcards": "modelcards>=0.1.4", "numpy": "numpy", + "onnxruntime": "onnxruntime", + "onnxruntime-gpu": "onnxruntime-gpu", "pytest": "pytest", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", @@ -25,5 +27,6 @@ deps = { "requests": "requests", "tensorboard": "tensorboard", "torch": "torch>=1.4", + "torchvision": "torchvision", "transformers": "transformers>=4.21.0", } diff --git a/src/diffusers/modeling_flax_pytorch_utils.py b/src/diffusers/modeling_flax_pytorch_utils.py new file mode 100644 index 0000000000..9c7a5de2ad --- /dev/null +++ b/src/diffusers/modeling_flax_pytorch_utils.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch - Flax general utilities.""" +import re + +import jax.numpy as jnp +from flax.traverse_util import flatten_dict, unflatten_dict +from jax.random import PRNGKey + +from .utils import logging + + +logger = logging.get_logger(__name__) + + +def rename_key(key): + regex = r"\w+[.]\d+" + pats = re.findall(regex, key) + for pat in pats: + key = key.replace(pat, "_".join(pat.split("."))) + return key + + +##################### +# PyTorch => Flax # +##################### + +# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 +# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py +def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): + """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary""" + + # conv norm or layer norm + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + if ( + any("norm" in str_ for str_ in pt_tuple_key) + and (pt_tuple_key[-1] == "bias") + and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict) + and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict) + ): + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + return renamed_pt_tuple_key, pt_tensor + elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict: + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) + return renamed_pt_tuple_key, pt_tensor + + # embedding + if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict: + pt_tuple_key = pt_tuple_key[:-1] + ("embedding",) + return renamed_pt_tuple_key, pt_tensor + + # conv layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4: + pt_tensor = pt_tensor.transpose(2, 3, 1, 0) + return renamed_pt_tuple_key, pt_tensor + + # linear layer + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",) + if pt_tuple_key[-1] == "weight": + pt_tensor = pt_tensor.T + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm weight + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",) + if pt_tuple_key[-1] == "gamma": + return renamed_pt_tuple_key, pt_tensor + + # old PyTorch layer norm bias + renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",) + if pt_tuple_key[-1] == "beta": + return renamed_pt_tuple_key, pt_tensor + + return pt_tuple_key, pt_tensor + + +def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42): + # Step 1: Convert pytorch tensor to numpy + pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + + # Step 2: Since the model is stateless, get random Flax params + random_flax_params = flax_model.init_weights(PRNGKey(init_key)) + + random_flax_state_dict = flatten_dict(random_flax_params) + flax_state_dict = {} + + # Need to change some parameters name to match Flax names + for pt_key, pt_tensor in pt_state_dict.items(): + renamed_pt_key = rename_key(pt_key) + pt_tuple_key = 
tuple(renamed_pt_key.split(".")) + + # Correctly rename weight parameters + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict) + + if flax_key in random_flax_state_dict: + if flax_tensor.shape != random_flax_state_dict[flax_key].shape: + raise ValueError( + f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " + f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + + return unflatten_dict(flax_state_dict) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index f195462eca..7f1d65e2ed 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -27,12 +27,18 @@ from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError -from .modeling_utils import WEIGHTS_NAME -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging +from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax +from .modeling_utils import load_state_dict +from .utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + FLAX_WEIGHTS_NAME, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_NAME, + logging, +) -FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" - logger = logging.get_logger(__name__) @@ -45,7 +51,7 @@ class FlaxModelMixin: """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] - _flax_internal_args = ["name", "parent"] + _flax_internal_args = ["name", "parent", "dtype"] @classmethod def _from_config(cls, config, **kwargs): @@ -245,6 +251,8 @@ class FlaxModelMixin: The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., `output_attentions=True`). 
Behaves differently depending on whether a `config` is provided or @@ -272,6 +280,7 @@ class FlaxModelMixin: config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) + from_pt = kwargs.pop("from_pt", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) @@ -294,32 +303,44 @@ class FlaxModelMixin: local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, + subfolder=subfolder, # model args dtype=dtype, **kwargs, ) # Load model - if os.path.isdir(pretrained_model_name_or_path): - if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): + pretrained_path_with_subfolder = ( + pretrained_model_name_or_path + if subfolder is None + else os.path.join(pretrained_model_name_or_path, subfolder) + ) + if os.path.isdir(pretrained_path_with_subfolder): + if from_pt: + if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): + raise EnvironmentError( + f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} " + ) + model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME) + elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint - model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) - # At this stage we don't have a weight file so we will raise an error. - elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME) + # Check if pytorch weights exist instead + elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)): raise EnvironmentError( - f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " - "but there is a file for PyTorch weights." + f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model" + " using `from_pt=True`." ) else: raise EnvironmentError( f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " - f"{pretrained_model_name_or_path}." + f"{pretrained_path_with_subfolder}." ) else: try: model_file = hf_hub_download( pretrained_model_name_or_path, - filename=FLAX_WEIGHTS_NAME, + filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME, cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -369,25 +390,32 @@ class FlaxModelMixin: f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." ) - try: - with open(model_file, "rb") as state_f: - state = from_bytes(cls, state_f.read()) - except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + if from_pt: + # Step 1: Get the pytorch file + pytorch_model_file = load_state_dict(model_file) + + # Step 2: Convert the weights + state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model) + else: try: - with open(model_file) as f: - if f.read().startswith("version"): - raise OSError( - "You seem to have cloned a repository without having git-lfs installed. Please" - " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" - " folder you cloned." - ) - else: - raise ValueError from e - except (UnicodeDecodeError, ValueError): - raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. 
") - # make sure all arrays are stored as jnp.ndarray - # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: - # https://github.com/google/flax/issues/1261 + with open(model_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + try: + with open(model_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") + # make sure all arrays are stored as jnp.ndarray + # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) # flatten dicts @@ -408,7 +436,7 @@ class FlaxModelMixin: ) cls._missing_keys = missing_keys - # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not # matching the weights in the model. mismatched_keys = [] for key in state.keys(): diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index ef1ead9ecf..659f2ee8a6 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -15,6 +15,7 @@ # limitations under the License. import os +from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -24,10 +25,7 @@ from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging - - -WEIGHTS_NAME = "diffusion_pytorch_model.bin" +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging logger = logging.get_logger(__name__) @@ -124,10 +122,42 @@ class ModelMixin(torch.nn.Module): """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False def __init__(self): super().__init__() + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if not self._supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def disable_gradient_checkpointing(self): + """ + Deactivates gradient checkpointing for the current model. 
+ + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self._supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e0ac5c8d54..1242ad6fca 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -12,6 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .unet_2d import UNet2DModel -from .unet_2d_condition import UNet2DConditionModel -from .vae import AutoencoderKL, VQModel +from ..utils import is_flax_available, is_torch_available + + +if is_torch_available(): + from .unet_2d import UNet2DModel + from .unet_2d_condition import UNet2DConditionModel + from .vae import AutoencoderKL, VQModel + +if is_flax_available(): + from .unet_2d_condition_flax import FlaxUNet2DConditionModel + from .vae_flax import FlaxAutoencoderKL diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e4cedbff8c..25e1ea28dc 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -249,13 +249,15 @@ class CrossAttention(nn.Module): return tensor def forward(self, hidden_states, context=None, mask=None): - batch_size, sequence_length, dim = hidden_states.shape + batch_size, sequence_length, _ = hidden_states.shape query = self.to_q(hidden_states) context = context if context is not None else hidden_states key = self.to_k(context) value = self.to_v(context) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 918c7469a7..1745265b91 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -17,6 +17,22 @@ import jax.numpy as jnp class FlaxAttentionBlock(nn.Module): + r""" + A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 + + Parameters: + query_dim (:obj:`int`): + Input hidden states dimension + heads (:obj:`int`, *optional*, defaults to 8): + Number of heads + dim_head (:obj:`int`, *optional*, defaults to 64): + Hidden states dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + + """ query_dim: int heads: int = 8 dim_head: int = 64 @@ -32,7 +48,7 @@ class FlaxAttentionBlock(nn.Module): self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") - self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out") + self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0") def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape @@ -74,6 +90,23 @@ class FlaxAttentionBlock(nn.Module): class FlaxBasicTransformerBlock(nn.Module): + r""" + A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in: + https://arxiv.org/abs/1706.03762 + + + Parameters: + dim (:obj:`int`): + Inner hidden states dimension + n_heads (:obj:`int`): + Number of heads + d_head (:obj:`int`): + Hidden states 
dimension inside each head + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ dim: int n_heads: int d_head: int @@ -82,9 +115,9 @@ class FlaxBasicTransformerBlock(nn.Module): def setup(self): # self attention - self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) # cross attention - self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) @@ -93,12 +126,12 @@ class FlaxBasicTransformerBlock(nn.Module): def __call__(self, hidden_states, context, deterministic=True): # self attention residual = hidden_states - hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic) + hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual # cross attention residual = hidden_states - hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic) + hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic) hidden_states = hidden_states + residual # feed forward @@ -110,6 +143,25 @@ class FlaxBasicTransformerBlock(nn.Module): class FlaxSpatialTransformer(nn.Module): + r""" + A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in: + https://arxiv.org/pdf/1506.02025.pdf + + + Parameters: + in_channels (:obj:`int`): + Input number of channels + n_heads (:obj:`int`): + Number of heads + d_head (:obj:`int`): + Hidden states dimension inside each head + depth (:obj:`int`, *optional*, defaults to 1): + Number of transformers block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int n_heads: int d_head: int @@ -144,7 +196,6 @@ class FlaxSpatialTransformer(nn.Module): def __call__(self, hidden_states, context, deterministic=True): batch, height, width, channels = hidden_states.shape - # import ipdb; ipdb.set_trace() residual = hidden_states hidden_states = self.norm(hidden_states) hidden_states = self.proj_in(hidden_states) @@ -163,18 +214,56 @@ class FlaxSpatialTransformer(nn.Module): class FlaxGluFeedForward(nn.Module): + r""" + Flax module that encapsulates two Linear layers separated by a gated linear unit activation from: + https://arxiv.org/abs/2002.05202 + + Parameters: + dim (:obj:`int`): + Inner hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + dim: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # The second linear layer needs to be called + # net_2 for now to match the index of the Sequential layer + self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype) + self.net_2 = nn.Dense(self.dim, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.net_0(hidden_states) 
+ hidden_states = self.net_2(hidden_states) + return hidden_states + + +class FlaxGEGLU(nn.Module): + r""" + Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from + https://arxiv.org/abs/2002.05202. + + Parameters: + dim (:obj:`int`): + Input hidden states dimension + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ dim: int dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 def setup(self): inner_dim = self.dim * 4 - self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype) - self.dense2 = nn.Dense(self.dim, dtype=self.dtype) + self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype) def __call__(self, hidden_states, deterministic=True): - hidden_states = self.dense1(hidden_states) + hidden_states = self.proj(hidden_states) hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) - hidden_states = hidden_linear * nn.gelu(hidden_gelu) - hidden_states = self.dense2(hidden_states) - return hidden_states + return hidden_linear * nn.gelu(hidden_gelu) diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py index 63442ab997..e2d607499c 100644 --- a/src/diffusers/models/embeddings_flax.py +++ b/src/diffusers/models/embeddings_flax.py @@ -19,7 +19,7 @@ import jax.numpy as jnp # This is like models.embeddings.get_timestep_embedding (PyTorch) but # less general (only handles the case we currently need). -def get_sinusoidal_embeddings(timesteps, embedding_dim): +def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1): """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. @@ -29,7 +29,7 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim): embeddings. :return: an [N x dim] tensor of positional embeddings. """ half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) + emb = math.log(10000) / (half_dim - freq_shift) emb = jnp.exp(jnp.arange(half_dim) * -emb) emb = timesteps[:, None] * emb[None, :] emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1) @@ -37,6 +37,15 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim): class FlaxTimestepEmbedding(nn.Module): + r""" + Time step Embedding Module. Learns embeddings for input time steps. + + Args: + time_embed_dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ time_embed_dim: int = 32 dtype: jnp.dtype = jnp.float32 @@ -49,8 +58,16 @@ class FlaxTimestepEmbedding(nn.Module): class FlaxTimesteps(nn.Module): + r""" + Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239 + + Args: + dim (`int`, *optional*, defaults to `32`): + Time step embedding dimension + """ dim: int = 32 + freq_shift: float = 1 @nn.compact def __call__(self, timesteps): - return get_sinusoidal_embeddings(timesteps, self.dim) + return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 55ae42c2e3..97f3c02a8c 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -149,7 +149,6 @@ class FirUpsample2D(nn.Module): stride = (factor, factor) # Determine data dimensions. 
- stride = [1, 1, factor, factor] output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW) output_padding = ( output_shape[0] - (x.shape[2] - 1) * stride[0] - convH, @@ -161,7 +160,7 @@ class FirUpsample2D(nn.Module): # Transpose weights. weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) - weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) + weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index cdb0462194..5e3ee091c3 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -3,12 +3,21 @@ from typing import Optional, Tuple, Union import torch import torch.nn as nn +import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import TimestepEmbedding, Timesteps -from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block +from .unet_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UpBlock2D, + get_down_block, + get_up_block, +) @dataclass @@ -54,6 +63,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. """ + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -188,6 +199,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): if hasattr(block, "attentions") and block.attentions is not None: block.set_attention_slice(slice_size) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + module.gradient_checkpointing = value + def forward( self, sample: torch.FloatTensor, @@ -234,7 +249,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): for downsample_block in self.down_blocks: if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: sample, res_samples = downsample_block( - hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 1ac68e10c1..9e84da0687 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass from typing import Tuple, Union +import flax import flax.linen as nn import jax import jax.numpy as jnp @@ -19,7 +19,7 @@ from .unet_blocks_flax import ( ) -@dataclass +@flax.struct.dataclass class FlaxUNet2DConditionOutput(BaseOutput): """ Args: @@ -39,10 +39,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library implements for all the models (such as downloading or saving, etc.) + Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. 
Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + Parameters: - sample_size (`int`, *optional*): The size of the input sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + sample_size (`int`, *optional*): + The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): + The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): + The number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" @@ -51,10 +64,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. - dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + attention_head_dim (`int`, *optional*, defaults to 8): + The dimension of the attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): + The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): + Dropout probability for down, up and bottleneck blocks. 
""" sample_size: int = 32 @@ -73,10 +90,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): cross_attention_dim: int = 1280 dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 + freq_shift: int = 0 def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: # init input tensors - sample_shape = (1, self.sample_size, self.sample_size, self.in_channels) + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) timesteps = jnp.ones((1,), dtype=jnp.int32) encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) @@ -100,7 +118,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): ) # time - self.time_proj = FlaxTimesteps(block_out_channels[0]) + self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift) self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) # down @@ -214,10 +232,17 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): When returning a tuple, the first element is the sample tensor. """ # 1. time + if not isinstance(timesteps, jnp.ndarray): + timesteps = jnp.array([timesteps], dtype=jnp.int32) + elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0: + timesteps = timesteps.astype(dtype=jnp.float32) + timesteps = jnp.expand_dims(timesteps, 0) + t_emb = self.time_proj(timesteps) t_emb = self.time_embedding(t_emb) # 2. pre-process + sample = jnp.transpose(sample, (0, 2, 3, 1)) sample = self.conv_in(sample) # 3. down @@ -251,6 +276,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): sample = self.conv_norm_out(sample) sample = nn.silu(sample) sample = self.conv_out(sample) + sample = jnp.transpose(sample, (0, 3, 1, 2)) if not return_dict: return (sample,) diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index 1fee670c91..f42389b985 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -527,6 +527,8 @@ class CrossAttnDownBlock2D(nn.Module): else: self.downsamplers = None + self.gradient_checkpointing = False + def set_attention_slice(self, slice_size): if slice_size is not None and self.attn_num_head_channels % slice_size != 0: raise ValueError( @@ -546,8 +548,22 @@ class CrossAttnDownBlock2D(nn.Module): output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn), hidden_states, encoder_hidden_states + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + output_states += (hidden_states,) if self.downsamplers is not None: @@ -609,11 +625,24 @@ class DownBlock2D(nn.Module): else: self.downsamplers = None + self.gradient_checkpointing = False + def forward(self, hidden_states, temb=None): output_states = () for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def 
custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) if self.downsamplers is not None: @@ -1072,6 +1101,8 @@ class CrossAttnUpBlock2D(nn.Module): else: self.upsamplers = None + self.gradient_checkpointing = False + def set_attention_slice(self, slice_size): if slice_size is not None and self.attn_num_head_channels % slice_size != 0: raise ValueError( @@ -1087,15 +1118,36 @@ class CrossAttnUpBlock2D(nn.Module): for attn in self.attentions: attn._set_attention_slice(slice_size) - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None): + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + ): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn), hidden_states, encoder_hidden_states + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1150,6 +1202,8 @@ class UpBlock2D(nn.Module): else: self.upsamplers = None + self.gradient_checkpointing = False + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): for resnet in self.resnets: # pop res hidden states @@ -1157,7 +1211,17 @@ class UpBlock2D(nn.Module): res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index ce67eb12b1..c23b7844c3 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -19,6 +19,26 @@ from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D class FlaxCrossAttnDownBlock2D(nn.Module): + r""" + Cross Attention 2D Downsizing block - original architecture from Unet transformers: + https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + 
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int out_channels: int dropout: float = 0.0 @@ -55,7 +75,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module): self.attentions = attentions if self.add_downsample: - self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): output_states = () @@ -66,13 +86,30 @@ class FlaxCrossAttnDownBlock2D(nn.Module): output_states += (hidden_states,) if self.add_downsample: - hidden_states = self.downsample(hidden_states) + hidden_states = self.downsamplers_0(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class FlaxDownBlock2D(nn.Module): + r""" + Flax 2D downsizing block + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int out_channels: int dropout: float = 0.0 @@ -96,7 +133,7 @@ class FlaxDownBlock2D(nn.Module): self.resnets = resnets if self.add_downsample: - self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, temb, deterministic=True): output_states = () @@ -106,13 +143,33 @@ class FlaxDownBlock2D(nn.Module): output_states += (hidden_states,) if self.add_downsample: - hidden_states = self.downsample(hidden_states) + hidden_states = self.downsamplers_0(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class FlaxCrossAttnUpBlock2D(nn.Module): + r""" + Cross Attention 2D Upsampling block - original architecture from Unet transformers: + https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + add_upsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add upsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int out_channels: int prev_output_channel: int @@ -151,7 +208,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module): self.attentions = attentions if self.add_upsample: - self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): for resnet, attn in zip(self.resnets, self.attentions): @@ 
-164,12 +221,31 @@ class FlaxCrossAttnUpBlock2D(nn.Module): hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) if self.add_upsample: - hidden_states = self.upsample(hidden_states) + hidden_states = self.upsamplers_0(hidden_states) return hidden_states class FlaxUpBlock2D(nn.Module): + r""" + Flax 2D upsampling block + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + prev_output_channel (:obj:`int`): + Output channels from the previous block + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsampling layer before each final output + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int out_channels: int prev_output_channel: int @@ -196,7 +272,7 @@ class FlaxUpBlock2D(nn.Module): self.resnets = resnets if self.add_upsample: - self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): for resnet in self.resnets: @@ -208,12 +284,27 @@ class FlaxUpBlock2D(nn.Module): hidden_states = resnet(hidden_states, temb, deterministic=deterministic) if self.add_upsample: - hidden_states = self.upsample(hidden_states) + hidden_states = self.upsamplers_0(hidden_states) return hidden_states class FlaxUNetMidBlock2DCrossAttn(nn.Module): + r""" + Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104 + + Parameters: + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of attention blocks layers + attn_num_head_channels (:obj:`int`, *optional*, defaults to 1): + Number of attention heads of each spatial transformer block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ in_channels: int dropout: float = 0.0 num_layers: int = 1 diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py new file mode 100644 index 0000000000..b3261b11cf --- /dev/null +++ b/src/diffusers/models/vae_flax.py @@ -0,0 +1,812 @@ +# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers + +import math +from functools import partial +from typing import Tuple + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..modeling_flax_utils import FlaxModelMixin +from ..utils import BaseOutput + + +@flax.struct.dataclass +class FlaxDecoderOutput(BaseOutput): + """ + Output of decoding method. + + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Decoded output sample of the model. Output of the last layer of the model. + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + sample: jnp.ndarray + + +@flax.struct.dataclass +class FlaxAutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. 
+ + Args: + latent_dist (`FlaxDiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`. + `FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "FlaxDiagonalGaussianDistribution" + + +class FlaxUpsample2D(nn.Module): + """ + Flax implementation of 2D Upsample layer + + Args: + in_channels (`int`): + Input channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.in_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + batch, height, width, channels = hidden_states.shape + hidden_states = jax.image.resize( + hidden_states, + shape=(batch, height * 2, width * 2, channels), + method="nearest", + ) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxDownsample2D(nn.Module): + """ + Flax implementation of 2D Downsample layer + + Args: + in_channels (`int`): + Input channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.in_channels, + kernel_size=(3, 3), + strides=(2, 2), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim + hidden_states = jnp.pad(hidden_states, pad_width=pad) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxResnetBlock2D(nn.Module): + """ + Flax implementation of 2D Resnet Block. + + Args: + in_channels (`int`): + Input channels + out_channels (`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): + Whether to use `nin_shortcut`. 
This activates a new layer inside ResNet block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + + in_channels: int + out_channels: int = None + dropout: float = 0.0 + use_nin_shortcut: bool = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + out_channels = self.in_channels if self.out_channels is None else self.out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6) + self.conv1 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6) + self.dropout_layer = nn.Dropout(self.dropout) + self.conv2 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut + + self.conv_shortcut = None + if use_nin_shortcut: + self.conv_shortcut = nn.Conv( + out_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, deterministic=True): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual + + +class FlaxAttentionBlock(nn.Module): + r""" + Flax Convolutional based multi-head attention block for diffusion-based VAE. + + Parameters: + channels (:obj:`int`): + Input channels + num_head_channels (:obj:`int`, *optional*, defaults to `None`): + Number of attention heads + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + + """ + channels: int + num_head_channels: int = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1 + + dense = partial(nn.Dense, self.channels, dtype=self.dtype) + + self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6) + self.query, self.key, self.value = dense(), dense(), dense() + self.proj_attn = dense() + + def transpose_for_scores(self, projection): + new_projection_shape = projection.shape[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) + new_projection = projection.reshape(new_projection_shape) + # (B, T, H, D) -> (B, H, T, D) + new_projection = jnp.transpose(new_projection, (0, 2, 1, 3)) + return new_projection + + def __call__(self, hidden_states): + residual = hidden_states + batch, height, width, channels = hidden_states.shape + + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.reshape((batch, height * width, channels)) + + query = self.query(hidden_states) + key = self.key(hidden_states) + value = self.value(hidden_states) + + # transpose + query = self.transpose_for_scores(query) + key = self.transpose_for_scores(key) + value = self.transpose_for_scores(value) + + # compute attentions + scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) + attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale) + attn_weights = nn.softmax(attn_weights, axis=-1) + + # attend to values + 
hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights) + + hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3)) + new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,) + hidden_states = hidden_states.reshape(new_hidden_states_shape) + + hidden_states = self.proj_attn(hidden_states) + hidden_states = hidden_states.reshape((batch, height, width, channels)) + hidden_states = hidden_states + residual + return hidden_states + + +class FlaxDownEncoderBlock2D(nn.Module): + r""" + Flax Resnet blocks-based Encoder block for diffusion-based VAE. + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsample layer + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + self.resnets = resnets + + if self.add_downsample: + self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, deterministic=deterministic) + + if self.add_downsample: + hidden_states = self.downsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUpEncoderBlock2D(nn.Module): + r""" + Flax Resnet blocks-based Encoder block for diffusion-based VAE. + + Parameters: + in_channels (:obj:`int`): + Input channels + out_channels (:obj:`int`): + Output channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + add_downsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add downsample layer + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + + if self.add_upsample: + self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsamplers_0(hidden_states) + + return hidden_states + + +class FlaxUNetMidBlock2D(nn.Module): + r""" + Flax Unet Mid-Block module. 
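A note on the attention math in FlaxAttentionBlock above: the 1/sqrt(sqrt(channels/num_heads)) scale is applied to the query and the key separately, so their product still carries the standard 1/sqrt(d) softmax scaling while keeping the intermediate einsum values smaller. A standalone sketch of that computation, assuming (batch, heads, tokens, head_dim) inputs:

    import math
    import jax.numpy as jnp
    import flax.linen as nn

    def attend(query, key, value, channels, num_heads):
        # scale each operand by 1/sqrt(sqrt(d)) so q.k effectively carries the usual 1/sqrt(d)
        scale = 1 / math.sqrt(math.sqrt(channels / num_heads))
        attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
        attn_weights = nn.softmax(attn_weights, axis=-1)
        return jnp.einsum("...kc,...qk->...qc", value, attn_weights)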
+ + Parameters: + in_channels (:obj:`int`): + Input channels + dropout (:obj:`float`, *optional*, defaults to 0.0): + Dropout rate + num_layers (:obj:`int`, *optional*, defaults to 1): + Number of Resnet layer block + attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`): + Number of attention heads for each attention block + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # there is always at least one resnet + resnets = [ + FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout=self.dropout, + dtype=self.dtype, + ) + ] + + attentions = [] + + for _ in range(self.num_layers): + attn_block = FlaxAttentionBlock( + channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype + ) + attentions.append(attn_block) + + res_block = FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + self.attentions = attentions + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.resnets[0](hidden_states, deterministic=deterministic) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, deterministic=deterministic) + + return hidden_states + + +class FlaxEncoder(nn.Module): + r""" + Flax Implementation of VAE Encoder. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. 
+ + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + DownEncoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + norm_num_groups (:obj:`int`, *optional*, defaults to `2`): + norm num group + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + double_z (:obj:`bool`, *optional*, defaults to `False`): + Whether to double the last output channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + Parameters `dtype` + """ + in_channels: int = 3 + out_channels: int = 3 + down_block_types: Tuple[str] = ("DownEncoderBlock2D",) + block_out_channels: Tuple[int] = (64,) + layers_per_block: int = 2 + norm_num_groups: int = 32 + act_fn: str = "silu" + double_z: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + block_out_channels = self.block_out_channels + # in + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # downsampling + down_blocks = [] + output_channel = block_out_channels[0] + for i, _ in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = FlaxDownEncoderBlock2D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + down_blocks.append(down_block) + self.down_blocks = down_blocks + + # middle + self.mid_block = FlaxUNetMidBlock2D( + in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype + ) + + # end + conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels + self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6) + self.conv_out = nn.Conv( + conv_out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, sample, deterministic: bool = True): + # in + sample = self.conv_in(sample) + + # downsampling + for block in self.down_blocks: + sample = block(sample, deterministic=deterministic) + + # middle + sample = self.mid_block(sample, deterministic=deterministic) + + # end + sample = self.conv_norm_out(sample) + sample = nn.swish(sample) + sample = self.conv_out(sample) + + return sample + + +class FlaxDecoder(nn.Module): + r""" + Flax Implementation of VAE Decoder. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. 
+ + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + UpDecoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): + norm num group + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + double_z (:obj:`bool`, *optional*, defaults to `False`): + Whether to double the last output channels + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ + in_channels: int = 3 + out_channels: int = 3 + up_block_types: Tuple[str] = ("UpDecoderBlock2D",) + block_out_channels: int = (64,) + layers_per_block: int = 2 + norm_num_groups: int = 32 + act_fn: str = "silu" + dtype: jnp.dtype = jnp.float32 + + def setup(self): + block_out_channels = self.block_out_channels + + # z to block_in + self.conv_in = nn.Conv( + block_out_channels[-1], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # middle + self.mid_block = FlaxUNetMidBlock2D( + in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype + ) + + # upsampling + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + up_blocks = [] + for i, _ in enumerate(self.up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = FlaxUpEncoderBlock2D( + in_channels=prev_output_channel, + out_channels=output_channel, + num_layers=self.layers_per_block + 1, + add_upsample=not is_final_block, + dtype=self.dtype, + ) + up_blocks.append(up_block) + prev_output_channel = output_channel + + self.up_blocks = up_blocks + + # end + self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6) + self.conv_out = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, sample, deterministic: bool = True): + # z to block_in + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample, deterministic=deterministic) + + # upsampling + for block in self.up_blocks: + sample = block(sample, deterministic=deterministic) + + sample = self.conv_norm_out(sample) + sample = nn.swish(sample) + sample = self.conv_out(sample) + + return sample + + +class FlaxDiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + # Last axis to account for channels-last + self.mean, self.logvar = jnp.split(parameters, 2, axis=-1) + self.logvar = jnp.clip(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = jnp.exp(0.5 * self.logvar) + self.var = 
jnp.exp(self.logvar) + if self.deterministic: + self.var = self.std = jnp.zeros_like(self.mean) + + def sample(self, key): + return self.mean + self.std * jax.random.normal(key, self.mean.shape) + + def kl(self, other=None): + if self.deterministic: + return jnp.array([0.0]) + + if other is None: + return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3]) + + return 0.5 * jnp.sum( + jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, + axis=[1, 2, 3], + ) + + def nll(self, sample, axis=[1, 2, 3]): + if self.deterministic: + return jnp.array([0.0]) + + logtwopi = jnp.log(2.0 * jnp.pi) + return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis) + + def mode(self): + return self.mean + + +@flax_register_to_config +class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational + Bayes by Diederik P. Kingma and Max Welling. + + This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + general usage and behavior. + + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + in_channels (:obj:`int`, *optional*, defaults to 3): + Input channels + out_channels (:obj:`int`, *optional*, defaults to 3): + Output channels + down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + DownEncoder block type + up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + UpDecoder block type + block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple containing the number of output channels for each block + layers_per_block (:obj:`int`, *optional*, defaults to `2`): + Number of Resnet layer for each block + act_fn (:obj:`str`, *optional*, defaults to `silu`): + Activation function + latent_channels (:obj:`int`, *optional*, defaults to `4`): + Latent space channels + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): + Norm num group + sample_size (:obj:`int`, *optional*, defaults to `32`): + Sample input size + dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): + parameters `dtype` + """ + in_channels: int = 3 + out_channels: int = 3 + down_block_types: Tuple[str] = ("DownEncoderBlock2D",) + up_block_types: Tuple[str] = ("UpDecoderBlock2D",) + block_out_channels: Tuple[int] = (64,) + layers_per_block: int = 1 + act_fn: str = "silu" + latent_channels: int = 4 + norm_num_groups: int = 32 + sample_size: int = 32 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.encoder = FlaxEncoder( + in_channels=self.config.in_channels, + out_channels=self.config.latent_channels, + down_block_types=self.config.down_block_types, + block_out_channels=self.config.block_out_channels, + layers_per_block=self.config.layers_per_block, + act_fn=self.config.act_fn, + norm_num_groups=self.config.norm_num_groups, + 
double_z=True, + dtype=self.dtype, + ) + self.decoder = FlaxDecoder( + in_channels=self.config.latent_channels, + out_channels=self.config.out_channels, + up_block_types=self.config.up_block_types, + block_out_channels=self.config.block_out_channels, + layers_per_block=self.config.layers_per_block, + norm_num_groups=self.config.norm_num_groups, + act_fn=self.config.act_fn, + dtype=self.dtype, + ) + self.quant_conv = nn.Conv( + 2 * self.config.latent_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + self.post_quant_conv = nn.Conv( + self.config.latent_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: + # init input tensors + sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + + params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3) + rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng} + + return self.init(rngs, sample)["params"] + + def encode(self, sample, deterministic: bool = True, return_dict: bool = True): + sample = jnp.transpose(sample, (0, 2, 3, 1)) + + hidden_states = self.encoder(sample, deterministic=deterministic) + moments = self.quant_conv(hidden_states) + posterior = FlaxDiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return FlaxAutoencoderKLOutput(latent_dist=posterior) + + def decode(self, latents, deterministic: bool = True, return_dict: bool = True): + if latents.shape[-1] != self.config.latent_channels: + latents = jnp.transpose(latents, (0, 2, 3, 1)) + + hidden_states = self.post_quant_conv(latents) + hidden_states = self.decoder(hidden_states, deterministic=deterministic) + + hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) + + if not return_dict: + return (hidden_states,) + + return FlaxDecoderOutput(sample=hidden_states) + + def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True): + posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict) + if sample_posterior: + rng = self.make_rng("gaussian") + hidden_states = posterior.latent_dist.sample(rng) + else: + hidden_states = posterior.latent_dist.mode() + + sample = self.decode(hidden_states, return_dict=return_dict).sample + + if not return_dict: + return (sample,) + + return FlaxDecoderOutput(sample=sample) diff --git a/src/diffusers/onnx_utils.py b/src/diffusers/onnx_utils.py index 3c2a0b4829..2282f411ae 100644 --- a/src/diffusers/onnx_utils.py +++ b/src/diffusers/onnx_utils.py @@ -24,16 +24,13 @@ import numpy as np from huggingface_hub import hf_hub_download -from .utils import is_onnx_available, logging +from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging if is_onnx_available(): import onnxruntime as ort -ONNX_WEIGHTS_NAME = "model.onnx" - - logger = logging.get_logger(__name__) @@ -49,7 +46,7 @@ class OnnxRuntimeModel: return self.model.run(None, inputs) @staticmethod - def load_model(path: Union[str, Path], provider=None): + def load_model(path: Union[str, Path], provider=None, sess_options=None): """ Loads an ONNX Inference session with an ExecutionProvider. 
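The new sess_options argument in onnx_utils.py threads an onnxruntime.SessionOptions object through to the InferenceSession, and from_pretrained forwards it as well. A minimal sketch, assuming onnxruntime is installed and using a hypothetical local model.onnx path:

    import onnxruntime as ort
    from diffusers.onnx_utils import OnnxRuntimeModel

    options = ort.SessionOptions()
    options.intra_op_num_threads = 4  # example tuning knob
    session = OnnxRuntimeModel.load_model(
        "model.onnx", provider="CPUExecutionProvider", sess_options=options
    )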
Default provider is `CPUExecutionProvider` @@ -63,7 +60,7 @@ class OnnxRuntimeModel: logger.info("No onnxruntime provider specified, using CPUExecutionProvider") provider = "CPUExecutionProvider" - return ort.InferenceSession(path, providers=[provider]) + return ort.InferenceSession(path, providers=[provider], sess_options=sess_options) def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs): """ @@ -117,6 +114,7 @@ class OnnxRuntimeModel: cache_dir: Optional[str] = None, file_name: Optional[str] = None, provider: Optional[str] = None, + sess_options: Optional["ort.SessionOptions"] = None, **kwargs, ): """ @@ -146,7 +144,9 @@ class OnnxRuntimeModel: model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME # load model from local directory if os.path.isdir(model_id): - model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider) + model = OnnxRuntimeModel.load_model( + os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options + ) kwargs["model_save_dir"] = Path(model_id) # load model from hub else: @@ -161,7 +161,7 @@ class OnnxRuntimeModel: ) kwargs["model_save_dir"] = Path(model_cache_path).parent kwargs["latest_model_name"] = Path(model_cache_path).name - model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider) + model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options) return cls(model=model, **kwargs) @classmethod diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py new file mode 100644 index 0000000000..9ea94ee0f2 --- /dev/null +++ b/src/diffusers/pipeline_flax_utils.py @@ -0,0 +1,474 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import importlib +import inspect +import os +from typing import Dict, List, Optional, Union + +import numpy as np + +import flax +import PIL +from flax.core.frozen_dict import FrozenDict +from huggingface_hub import snapshot_download +from PIL import Image +from tqdm.auto import tqdm + +from .configuration_utils import ConfigMixin +from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin +from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging + + +if is_transformers_available(): + from transformers import FlaxPreTrainedModel + +INDEX_FILE = "diffusion_flax_model.bin" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "FlaxModelMixin": ["save_pretrained", "from_pretrained"], + "SchedulerMixin": ["save_config", "from_config"], + "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], + }, + "transformers": { + "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], + "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], + "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], + }, +} + +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + + +class DummyChecker: + def __init__(self): + self.dummy = True + + +def import_flax_or_no_model(module, class_name): + try: + # 1. First make sure that if a Flax object is present, import this one + class_obj = getattr(module, "Flax" + class_name) + except AttributeError: + # 2. If this doesn't work, it's not a model and we don't append "Flax" + class_obj = getattr(module, class_name) + except AttributeError: + raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}") + + return class_obj + + +@flax.struct.dataclass +class FlaxImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class FlaxDiffusionPipeline(ConfigMixin): + r""" + Base class for all models. + + [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion + pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all + pipelines to: + + - enabling/disabling the progress bar for the denoising iteration + + Class attributes: + + - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all + components of the diffusion pipeline. + """ + config_name = "model_index.json" + + def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + + for name, module in kwargs.items(): + # retrieve library + library = module.__module__.split(".")[0] + + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + + # if library is not in LOADABLE_CLASSES, then it is a custom module. 
+ # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir + + # retrieve class_name + class_name = module.__class__.__name__ + + register_dict = {name: (library, class_name)} + + # save model index config + self.register_to_config(**register_dict) + + # set models + setattr(self, name, module) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]): + # TODO: handle inference_state + """ + Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to + a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading + method. The pipeline can easily be re-loaded using the `[`~FlaxDiffusionPipeline.from_pretrained`]` class + method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + """ + self.save_config(save_directory) + + model_index_dict = dict(self.config) + model_index_dict.pop("_class_name") + model_index_dict.pop("_diffusers_version") + model_index_dict.pop("_module", None) + + for pipeline_component_name in model_index_dict.keys(): + sub_model = getattr(self, pipeline_component_name) + model_cls = sub_model.__class__ + + save_method_name = None + # search for the model's base class in LOADABLE_CLASSES + for library_name, library_classes in LOADABLE_CLASSES.items(): + library = importlib.import_module(library_name) + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class) + if issubclass(model_cls, class_candidate): + # if we found a suitable base class in LOADABLE_CLASSES then grab its save method + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + # TODO(Patrick, Suraj): to delete after + if isinstance(sub_model, DummyChecker): + continue + + save_method = getattr(sub_model, save_method_name) + expects_params = "params" in set(inspect.signature(save_method).parameters.keys()) + + if expects_params: + save_method( + os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name] + ) + else: + save_method(os.path.join(save_directory, pipeline_component_name)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + r""" + Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. + + The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on + https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like + `CompVis/ldm-text2im-large-256`. + - A path to a *directory* containing pipeline weights saved using + [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`. 
+ dtype (`str` or `jnp.dtype`, *optional*): + Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. specify the folder name here. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the + specific pipeline class. The overwritten components are then directly passed to the pipelines + `__init__` method. See example below for more information. + + + + Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.* + `"CompVis/stable-diffusion-v1-4"` + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + Examples: + + ```py + >>> from diffusers import FlaxDiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") + + >>> # Download pipeline that requires an authorization token + >>> # For more information on access tokens, please refer to this section + >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) + >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) + + >>> # Download pipeline, but overwrite scheduler + >>> from diffusers import LMSDiscreteScheduler + + >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + >>> pipeline = FlaxDiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True + ... 
) + ``` + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_pt = kwargs.pop("from_pt", False) + dtype = kwargs.pop("dtype", None) + + # 1. Download the checkpoints and configs + # use snapshot download here to get it working from from_pretrained + if not os.path.isdir(pretrained_model_name_or_path): + config_dict = cls.get_config_dict( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + # make sure we only download sub-folders and `diffusers` filenames + folder_names = [k for k in config_dict.keys() if not k.startswith("_")] + allow_patterns = [os.path.join(k, "*") for k in folder_names] + allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] + + # download all allow_patterns + cached_folder = snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + allow_patterns=allow_patterns, + ) + else: + cached_folder = pretrained_model_name_or_path + + config_dict = cls.get_config_dict(cached_folder) + + # 2. Load the pipeline class, if using custom module then load it from the hub + # if we load from explicit class, let's use it + if cls != FlaxDiffusionPipeline: + pipeline_class = cls + else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + + # some modules can be passed directly to the init + # in this case they are already instantiated in `kwargs` + # extract them here + expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + + init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + + init_kwargs = {} + + # inference_params + params = {} + + # import it here to avoid circular import + from diffusers import pipelines + + # 3. Load each module in the pipeline + for name, (library_name, class_name) in init_dict.items(): + # TODO(Patrick, Suraj) - delete later + if class_name == "DummyChecker": + library_name = "stable_diffusion" + class_name = "FlaxStableDiffusionSafetyChecker" + + is_pipeline_module = hasattr(pipelines, library_name) + loaded_sub_model = None + + # if the model is in a pipeline module, then we load it from the pipeline + if name in passed_class_obj: + # 1. 
check that passed_class_obj has correct parent class
+ if not is_pipeline_module:
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+ expected_class_obj = None
+ for class_name, class_candidate in class_candidates.items():
+ if issubclass(class_obj, class_candidate):
+ expected_class_obj = class_candidate
+
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+ raise ValueError(
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+ f" {expected_class_obj}"
+ )
+ else:
+ logger.warn(
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+ " has the correct type"
+ )
+
+ # set passed class object
+ loaded_sub_model = passed_class_obj[name]
+ elif is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ if from_pt:
+ class_obj = import_flax_or_no_model(pipeline_module, class_name)
+ else:
+ class_obj = getattr(pipeline_module, class_name)
+
+ importable_classes = ALL_IMPORTABLE_CLASSES
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
+ else:
+ # else we just import it from the library.
+ library = importlib.import_module(library_name)
+ if from_pt:
+ class_obj = import_flax_or_no_model(library, class_name)
+ else:
+ class_obj = getattr(library, class_name)
+
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+ if loaded_sub_model is None:
+ load_method_name = None
+ for class_name, class_candidate in class_candidates.items():
+ if issubclass(class_obj, class_candidate):
+ load_method_name = importable_classes[class_name][1]
+
+ load_method = getattr(class_obj, load_method_name)
+
+ # check if the module is in a subdirectory
+ if os.path.isdir(os.path.join(cached_folder, name)):
+ loadable_folder = os.path.join(cached_folder, name)
+ else:
+ # otherwise load directly from the root of the cached folder
+ loadable_folder = cached_folder
+
+ if issubclass(class_obj, FlaxModelMixin):
+ loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
+ params[name] = loaded_params
+ elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
+ # make sure we don't initialize the weights to save time
+ if name == "safety_checker":
+ loaded_sub_model = DummyChecker()
+ loaded_params = {}
+ elif from_pt:
+ # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
+ loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
+ loaded_params = loaded_sub_model.params
+ del loaded_sub_model._params
+ else:
+ loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
+ params[name] = loaded_params
+ elif issubclass(class_obj, SchedulerMixin):
+ loaded_sub_model, scheduler_state = load_method(loadable_folder)
+ params[name] = scheduler_state
+ else:
+ loaded_sub_model = load_method(loadable_folder)
+
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
+
+ model = pipeline_class(**init_kwargs, dtype=dtype)
+ return model, params
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + # TODO: make it compatible with jax.lax + def progress_bar(self, iterable): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + return tqdm(iterable, **self._progress_bar_config) + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 847513bf15..fb8801bc95 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -30,10 +30,8 @@ from PIL import Image from tqdm.auto import tqdm from .configuration_utils import ConfigMixin -from .modeling_utils import WEIGHTS_NAME -from .onnx_utils import ONNX_WEIGHTS_NAME from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging INDEX_FILE = "diffusion_pytorch_model.bin" @@ -237,8 +235,8 @@ class DiffusionPipeline(ConfigMixin): kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the - specific pipeline class. The overritten components are then directly passed to the pipelines `__init__` - method. See example below for more information. + specific pipeline class. The overwritten components are then directly passed to the pipelines + `__init__` method. See example below for more information. @@ -284,6 +282,7 @@ class DiffusionPipeline(ConfigMixin): revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) provider = kwargs.pop("provider", None) + sess_options = kwargs.pop("sess_options", None) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -341,6 +340,10 @@ class DiffusionPipeline(ConfigMixin): # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): + # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names + if class_name.startswith("Flax"): + class_name = class_name[4:] + is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None @@ -396,6 +399,7 @@ class DiffusionPipeline(ConfigMixin): loading_kwargs["torch_dtype"] = torch_dtype if issubclass(class_obj, diffusers.OnnxRuntimeModel): loading_kwargs["provider"] = provider + loading_kwargs["sess_options"] = sess_options # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md index 7957a8c364..3462f5ff51 100644 --- a/src/diffusers/pipelines/README.md +++ b/src/diffusers/pipelines/README.md @@ -73,7 +73,7 @@ not be used for training. If you want to store the gradients during the forward We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**. -- **Self-contained**: A pipeline shall be as self-contained as possible. 
More specifically, this means that all functionality should be either directly defined in the pipeline file iteslf, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline. +- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline. - **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3e2aeb4fb2..8e3c8592a2 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,4 +1,4 @@ -from ..utils import is_onnx_available, is_transformers_available +from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available from .ddim import DDIMPipeline from .ddpm import DDPMPipeline from .latent_diffusion_uncond import LDMPipeline @@ -7,7 +7,7 @@ from .score_sde_ve import ScoreSdeVePipeline from .stochastic_karras_ve import KarrasVePipeline -if is_transformers_available(): +if is_torch_available() and is_transformers_available(): from .latent_diffusion import LDMTextToImagePipeline from .stable_diffusion import ( StableDiffusionImg2ImgPipeline, @@ -17,3 +17,6 @@ if is_transformers_available(): if is_transformers_available() and is_onnx_available(): from .stable_diffusion import StableDiffusionOnnxPipeline + +if is_transformers_available() and is_flax_available(): + from .stable_diffusion import FlaxStableDiffusionPipeline diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 8caa11dbdf..4a4f29be7f 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -27,7 +27,7 @@ class LDMTextToImagePipeline(DiffusionPipeline): vqvae ([`VQModel`]): Vector-quantized (VQ) Model to encode and decode images to and from latent representations. bert ([`LDMBertModel`]): - Text-encoder model based on [BERT](ttps://huggingface.co/docs/transformers/model_doc/bert) architecture. + Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture. tokenizer (`transformers.BertTokenizer`): Tokenizer of class [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer). @@ -397,7 +397,7 @@ class LDMBertAttention(nn.Module): attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned aross GPUs when using tensor-parallelism. 
+ # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) attn_output = self.out_proj(attn_output) @@ -478,7 +478,7 @@ class LDMBertEncoderLayer(nn.Module): class LDMBertPreTrainedModel(PreTrainedModel): config_class = LDMBertConfig base_model_prefix = "model" - supports_gradient_checkpointing = True + _supports_gradient_checkpointing = True _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] def _init_weights(self, module): diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md index 0e6cab2e11..3a600c5859 100644 --- a/src/diffusers/pipelines/stable_diffusion/README.md +++ b/src/diffusers/pipelines/stable_diffusion/README.md @@ -67,7 +67,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt).sample[0] + image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` @@ -89,7 +89,7 @@ pipe = StableDiffusionPipeline.from_pretrained( prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt).sample[0] + image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` @@ -115,7 +115,7 @@ pipe = StableDiffusionPipeline.from_pretrained( prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt).sample[0] + image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 5ffda93f17..1016ce69e4 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -6,7 +6,7 @@ import numpy as np import PIL from PIL import Image -from ...utils import BaseOutput, is_onnx_available, is_transformers_available +from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available @dataclass @@ -35,3 +35,26 @@ if is_transformers_available(): if is_transformers_available() and is_onnx_available(): from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline + +if is_transformers_available() and is_flax_available(): + import flax + + @flax.struct.dataclass + class FlaxStableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: List[bool] + + from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline + from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py new file mode 100644 index 0000000000..870a715ef5 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -0,0 +1,217 @@ +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel +from ...pipeline_flax_utils import FlaxDiffusionPipeline +from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler +from . import FlaxStableDiffusionPipelineOutput +from .safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
+ """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + scheduler = scheduler.set_format("np") + self.dtype = dtype + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + def prepare_inputs(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + return text_input.input_ids + + def __call__( + self, + prompt_ids: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.PRNGKey, + num_inference_steps: Optional[int] = 50, + height: Optional[int] = 512, + width: Optional[int] = 512, + guidance_scale: Optional[float] = 7.5, + latents: Optional[jnp.array] = None, + return_dict: bool = True, + debug: bool = False, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. 
When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + # get prompt text embeddings + text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] + context = jnp.concatenate([uncond_embeddings, text_embeddings]) + + latents_shape = ( + batch_size, + self.unet.in_channels, + self.unet.sample_size, + self.unet.sample_size, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + ).sample + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps) + + if debug: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + # TODO: check when flax vae gets merged into main + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + + # image = jnp.asarray(image).transpose(0, 2, 3, 1) + # run safety checker + # TODO: check when flax safety checker gets merged into main + # safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + # image, has_nsfw_concept = self.safety_checker( + # images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"] + # ) + has_nsfw_concept = False + + if not return_dict: + return 
(image, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 9f1211b430..411f27308a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -36,7 +36,7 @@ class StableDiffusionPipeline(DiffusionPipeline): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. + Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. @@ -58,7 +58,7 @@ class StableDiffusionPipeline(DiffusionPipeline): if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: warnings.warn( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" @@ -278,8 +278,8 @@ class StableDiffusionPipeline(DiffusionPipeline): image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index e7adb4d1a3..46299bf3b3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -48,7 +48,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. + Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. 
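
For orientation, a rough single-device usage sketch of the new Flax pipeline added above (not part of the patch itself). It assumes the `CompVis/stable-diffusion-v1-4` checkpoint exposes Flax weights (otherwise `from_pt=True` would be needed) and that access has been granted; the output paths are placeholders.

```python
import jax
import numpy as np

from diffusers import FlaxStableDiffusionPipeline

# from_pretrained returns the pipeline and its weights separately, Flax-style
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=True
)

# tokenize the prompt and draw a PRNG key; weights are passed explicitly at call time
prompt_ids = pipeline.prepare_inputs("a photo of an astronaut riding a horse on mars")
prng_seed = jax.random.PRNGKey(0)

output = pipeline(prompt_ids, params, prng_seed, num_inference_steps=50)

# output.images is a (batch, height, width, channels) array in [0, 1]
images = pipeline.numpy_to_pil(np.asarray(output.images))
images[0].save("astronaut_rides_horse_flax.png")

# the same `params` dictionary is what save_pretrained expects
pipeline.save_pretrained("./stable-diffusion-flax", params=params)
```
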
feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. @@ -70,7 +70,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: warnings.warn( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" @@ -213,7 +213,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps) # get prompt text embeddings text_input = self.tokenizer( @@ -265,8 +265,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): sigma = self.scheduler.sigmas[t_index] # the model input needs to be scaled to match the continuous ODE formulation in K-LMS latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - latent_model_input = latent_model_input.to(self.unet.dtype) - t = t.to(self.unet.dtype) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -284,14 +282,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): # scale and decode the image latents with vae latents = 1 / 0.18215 * latents - image = self.vae.decode(latents.to(self.vae.dtype)).sample + image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index b9ad36f1a2..7de7925a30 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -12,7 +12,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, PNDMScheduler +from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import logging from . 
import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -66,7 +66,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. + Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. @@ -78,7 +78,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler], + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, ): @@ -89,7 +89,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: warnings.warn( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" @@ -241,8 +241,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + if isinstance(self.scheduler, LMSDiscreteScheduler): + timesteps = torch.tensor( + [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device + ) + else: + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) @@ -287,8 +292,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): + t_index = t_start + i # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[t_index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual noise_pred = self.unet(latent_model_input, t, 
encoder_hidden_states=text_embeddings).sample @@ -299,10 +309,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index)) + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) latents = (init_latents_proper * mask) + (latents * (1 - mask)) # scale and decode the image latents with vae @@ -313,8 +328,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 09de92eeb1..3eb8828cdb 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -48,20 +48,20 @@ class StableDiffusionSafetyChecker(PreTrainedModel): # at the cost of increasing the possibility of filtering benign images adjustment = 0.0 - for concet_idx in range(len(special_cos_dist[0])): - concept_cos = special_cos_dist[i][concet_idx] - concept_threshold = self.special_care_embeds_weights[concet_idx].item() - result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["special_scores"][concet_idx] > 0: - result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]}) + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) adjustment = 0.01 - for concet_idx in range(len(cos_dist[0])): - concept_cos = cos_dist[i][concet_idx] - concept_threshold = self.concept_embeds_weights[concet_idx].item() - result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["concept_scores"][concet_idx] > 0: - result_img["bad_concepts"].append(concet_idx) + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + 
adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) result.append(result_img) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py new file mode 100644 index 0000000000..b3cd8eef02 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -0,0 +1,147 @@ +import warnings +from typing import Optional, Tuple + +import numpy as np + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.core.frozen_dict import FrozenDict +from transformers import CLIPConfig, FlaxPreTrainedModel +from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule + + +def jax_cosine_distance(emb_1, emb_2, eps=1e-12): + norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T + norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T + return jnp.matmul(norm_emb_1, norm_emb_2.T) + + +class FlaxStableDiffusionSafetyCheckerModule(nn.Module): + config: CLIPConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) + self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) + + self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) + self.special_care_embeds = self.param( + "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim) + ) + + self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) + self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) + + def __call__(self, clip_input): + pooled_output = self.vision_model(clip_input)[1] + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) + return special_cos_dist, cos_dist + + def filtered_with_scores(self, special_cos_dist, cos_dist, images): + batch_size = special_cos_dist.shape[0] + special_cos_dist = np.asarray(special_cos_dist) + cos_dist = np.asarray(cos_dist) + + result = [] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign image inputs + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + 
images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + +class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): + config_class = CLIPConfig + main_input_name = "clip_input" + module_class = FlaxStableDiffusionSafetyCheckerModule + + def __init__( + self, + config: CLIPConfig, + input_shape: Optional[Tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + if input_shape is None: + input_shape = (1, 224, 224, 3) + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensor + clip_input = jax.random.normal(rng, input_shape) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, clip_input)["params"] + + return random_params + + def __call__( + self, + clip_input, + params: dict = None, + ): + clip_input = jnp.transpose(clip_input, (0, 2, 3, 1)) + + return self.module.apply( + {"params": params or self.params}, + jnp.array(clip_input, dtype=jnp.float32), + rngs={}, + ) + + def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None): + def _filtered_with_scores(module, special_cos_dist, cos_dist, images): + return module.filtered_with_scores(special_cos_dist, cos_dist, images) + + return self.module.apply( + {"params": params or self.params}, + special_cos_dist, + cos_dist, + images, + method=_filtered_with_scores, + ) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 32be871f0b..0613ffd41d 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -17,13 +17,33 @@ import math import warnings +from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. 
+ """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -179,7 +199,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): use_clipped_model_output: bool = False, generator=None, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[DDIMSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -192,11 +212,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): eta (`float`): weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`): TODO generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -222,6 +242,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called @@ -260,7 +281,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): if not return_dict: return (prev_sample,) - return SchedulerOutput(prev_sample=prev_sample) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index dd3c2ac85d..d81d666071 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union import flax import jax.numpy as jnp -from jax import random from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -60,11 +59,12 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: class DDIMSchedulerState: # setable values timesteps: jnp.ndarray + alphas_cumprod: jnp.ndarray num_inference_steps: Optional[int] = None @classmethod - def create(cls, num_train_timesteps: int): - return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) + def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray): + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod) @dataclass @@ -96,9 +96,19 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. set_alpha_to_one (`bool`, default `True`): - if alpha for final step is 1 or the final alpha of the "non-previous" one. + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. 
When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self, @@ -106,12 +116,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): beta_start: float = 0.0001, beta_end: float = 0.02, beta_schedule: str = "linear", - trained_betas: Optional[jnp.ndarray] = None, - clip_sample: bool = True, set_alpha_to_one: bool = True, + steps_offset: int = 0, ): - if trained_betas is not None: - self.betas = jnp.asarray(trained_betas) if beta_schedule == "linear": self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) elif beta_schedule == "scaled_linear": @@ -124,19 +131,24 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + + # HACK for now - clean up later (PVP) + self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0]) - self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps) + def create_state(self): + return DDIMSchedulerState.create( + num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod + ) - def _get_variance(self, timestep, prev_timestep): - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + def _get_variance(self, timestep, prev_timestep, alphas_cumprod): + alpha_prod_t = alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -144,9 +156,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): return variance - def set_timesteps( - self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0 - ) -> DDIMSchedulerState: + def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -155,9 +165,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): the `FlaxDDIMScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): - optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. 
""" + offset = self.config.steps_offset + step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 @@ -172,9 +182,6 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - key: random.KeyArray, - eta: float = 0.0, - use_clipped_model_output: bool = False, return_dict: bool = True, ) -> Union[FlaxSchedulerOutput, Tuple]: """ @@ -213,44 +220,35 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" + # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function + eta = 0.0 + # 1. get previous step value (=t-1) prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps + alphas_cumprod = state.alphas_cumprod + # 2. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + alpha_prod_t = alphas_cumprod[timestep] + alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod) + beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - # 4. Clip "predicted x_0" - if self.config.clip_sample: - pred_original_sample = jnp.clip(pred_original_sample, -1, 1) - - # 5. compute variance: "sigma_t(η)" -> see formula (16) + # 4. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - variance = self._get_variance(timestep, prev_timestep) + variance = self._get_variance(timestep, prev_timestep, alphas_cumprod) std_dev_t = eta * variance ** (0.5) - if use_clipped_model_output: - # the model_output is always re-derived from the clipped x_0 in Glide - model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) - - # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output - # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # 6. 
compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
- if eta > 0:
- key = random.split(key, num=1)
- noise = random.normal(key=key, shape=model_output.shape)
- variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
-
- prev_sample = prev_sample + variance
-
if not return_dict:
return (prev_sample, state)
@@ -263,9 +261,14 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
- sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
- sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
- sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod[:, None]
+
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index fac75bc43e..440b880385 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -1,4 +1,4 @@
-# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,13 +15,33 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
+from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class DDPMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -177,7 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
predict_epsilon=True,
generator=None,
return_dict: bool = True,
- ) -> Union[SchedulerOutput, Tuple]:
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE.
Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -190,11 +210,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): predict_epsilon (`bool`): optional flag to use when model predicts the samples directly instead of the noise, epsilon. generator: random number generator. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -242,7 +262,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): if not return_dict: return (pred_prev_sample,) - return SchedulerOutput(prev_sample=pred_prev_sample) + return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index f686a2a322..7c7b8d29ab 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -1,4 +1,4 @@ -# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved. +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -266,9 +266,14 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin): timesteps: jnp.ndarray, ) -> jnp.ndarray: sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod[..., None] + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None] noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index caf7625fb6..98dafc72a7 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput): denoising loop. derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Derivative of predicted original image sample (x_0). + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. 
""" prev_sample: torch.FloatTensor derivative: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None class KarrasVeScheduler(SchedulerMixin, ConfigMixin): @@ -106,7 +110,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() self.schedule = [ ( - self.config.sigma_max + self.config.sigma_max**2 * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) ) for i in self.timesteps @@ -153,7 +157,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): sigma_hat (`float`): TODO sigma_prev (`float`): TODO sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check). Returns: @@ -170,7 +174,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): if not return_dict: return (sample_prev, derivative) - return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) + return KarrasVeOutput( + prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample + ) def step_correct( self, @@ -192,7 +198,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO derivative (`torch.FloatTensor` or `np.ndarray`): TODO - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class Returns: prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO @@ -205,7 +211,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): if not return_dict: return (sample_prev, derivative) - return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) + return KarrasVeOutput( + prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample + ) def add_noise(self, original_samples, noise, timesteps): raise NotImplementedError() diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index fe71b3fd0e..c320b79e6d 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -47,7 +47,7 @@ class FlaxKarrasVeOutput(BaseOutput): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): - Derivate of predicted original image sample (x_0). + Derivative of predicted original image sample (x_0). state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class. 
""" @@ -113,7 +113,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin): timesteps = jnp.arange(0, num_inference_steps)[::-1].copy() schedule = [ ( - self.config.sigma_max + self.config.sigma_max**2 * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1)) ) for i in timesteps diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 5857ae70a8..1dd6dbda1e 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np @@ -20,7 +21,26 @@ import torch from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +@dataclass +class LMSDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): @@ -133,7 +153,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): sample: Union[torch.FloatTensor, np.ndarray], order: int = 4, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[LMSDiscreteSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -144,12 +164,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): sample (`torch.FloatTensor` or `np.ndarray`): current instance of sample being created by diffusion process. order: coefficient for multi-step inference. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. 
""" sigma = self.sigmas[timestep] @@ -175,7 +195,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): if not return_dict: return (prev_sample,) - return SchedulerOutput(prev_sample=prev_sample) + return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 1431bdacf5..7f4c076b54 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -198,8 +198,11 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin): noise: jnp.ndarray, timesteps: jnp.ndarray, ) -> jnp.ndarray: - sigmas = self.match_shape(state.sigmas[timesteps], noise) - noisy_samples = original_samples + noise * sigmas + sigma = state.sigmas[timesteps].flatten() + while len(sigma.shape) < len(noise.shape): + sigma = sigma[..., None] + + noisy_samples = original_samples + noise * sigma return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 8444d66804..8344505620 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -59,7 +59,6 @@ class PNDMSchedulerState: # setable values _timesteps: jnp.ndarray num_inference_steps: Optional[int] = None - _offset: int = 0 prk_timesteps: Optional[jnp.ndarray] = None plms_timesteps: Optional[jnp.ndarray] = None timesteps: Optional[jnp.ndarray] = None @@ -104,8 +103,20 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `False`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. """ + @property + def has_state(self): + return True + @register_to_config def __init__( self, @@ -115,6 +126,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): beta_schedule: str = "linear", trained_betas: Optional[jnp.ndarray] = None, skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 0, ): if trained_betas is not None: self.betas = jnp.asarray(trained_betas) @@ -132,16 +145,17 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): self.alphas = 1.0 - self.betas self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) + self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # For now we only support F-PNDM, i.e. 
the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. self.pndm_order = 4 - self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps) + def create_state(self): + return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) - def set_timesteps( - self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0 - ) -> PNDMSchedulerState: + def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -150,16 +164,15 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): the `FlaxPNDMScheduler` state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): - optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. """ + offset = self.config.steps_offset + step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # rounding to avoid issues when num_inference_step is power of 3 - _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] - _timesteps = _timesteps + offset + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset - state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps) + state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps) if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to @@ -254,7 +267,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): ) diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2 - prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1]) + prev_timestep = timestep - diff_to_prev timestep = state.prk_timesteps[state.counter // 4 * 4] if state.counter % 4 == 0: @@ -274,7 +287,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): # cur_sample should not be `None` cur_sample = state.cur_sample if state.cur_sample is not None else sample - prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state) + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) if not return_dict: @@ -320,7 +333,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): "for more information." 
) - prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0) + prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps if state.counter != 1: state = state.replace(ets=state.ets.append(model_output)) @@ -344,7 +357,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4] ) - prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state) + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) if not return_dict: @@ -352,7 +365,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) - def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state): + def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # this function computes x_(t−δ) using the formula of (9) # Note that x_t needs to be added to both sides of the equation @@ -365,8 +378,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): # sample -> x_t # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) - alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset] - alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev @@ -395,9 +408,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): timesteps: jnp.ndarray, ) -> jnp.ndarray: sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod[..., None] + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None] noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index e5860706aa..08fbe14732 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -192,14 +192,17 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x) # also equation 47 shows the analog from SDE models to ancestral sampling methods - drift = drift - diffusion[:, None, None, None] ** 2 * model_output + diffusion = diffusion.flatten() + while len(diffusion.shape) < len(sample.shape): + diffusion = diffusion[:, None] + drift = drift - diffusion**2 * model_output # equation 6: sample noise for the diffusion term of key = random.split(key, num=1) noise = random.normal(key=key, shape=sample.shape) prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep # TODO 
is the variable diffusion the correct scaling term for the noise? - prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g + prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g if not return_dict: return (prev_sample, prev_sample_mean, state) @@ -248,8 +251,11 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): step_size = step_size * jnp.ones(sample.shape[0]) # compute corrected sample: model_output term and noise term - prev_sample_mean = sample + step_size[:, None, None, None] * model_output - prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise + step_size = step_size.flatten() + while len(step_size.shape) < len(sample.shape): + step_size = step_size[:, None] + prev_sample_mean = sample + step_size * model_output + prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise if not return_dict: return (prev_sample, state) diff --git a/src/diffusers/testing_utils.py b/src/diffusers/testing_utils.py index 7daf2bc633..d3f6fa628d 100644 --- a/src/diffusers/testing_utils.py +++ b/src/diffusers/testing_utils.py @@ -1,7 +1,9 @@ import os import random +import re import unittest from distutils.util import strtobool +from pathlib import Path from typing import Union import torch @@ -92,3 +94,157 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: image = PIL.ImageOps.exif_transpose(image) image = image.convert("RGB") return image + + +# --- pytest conf functions --- # + +# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once +pytest_opt_registered = {} + + +def pytest_addoption_shared(parser): + """ + This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there. + + It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest` + option. + + """ + option = "--make-reports" + if option not in pytest_opt_registered: + parser.addoption( + option, + action="store", + default=False, + help="generate report files. The value of this option is used as a prefix to report names", + ) + pytest_opt_registered[option] = 1 + + +def pytest_terminal_summary_main(tr, id): + """ + Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current + directory. The report files are prefixed with the test suite name. + + This function emulates --duration and -rA pytest arguments. + + This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined + there. + + Args: + - tr: `terminalreporter` passed from `conftest.py` + - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is + needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. + + NB: this functions taps into a private _pytest API and while unlikely, it could break should + pytest do internal changes - also it calls default internal methods of terminalreporter which + can be hijacked by various `pytest-` plugins and interfere. 
+ + """ + from _pytest.config import create_terminal_writer + + if not len(id): + id = "tests" + + config = tr.config + orig_writer = config.get_terminal_writer() + orig_tbstyle = config.option.tbstyle + orig_reportchars = tr.reportchars + + dir = "reports" + Path(dir).mkdir(parents=True, exist_ok=True) + report_files = { + k: f"{dir}/{id}_{k}.txt" + for k in [ + "durations", + "errors", + "failures_long", + "failures_short", + "failures_line", + "passes", + "stats", + "summary_short", + "warnings", + ] + } + + # custom durations report + # note: there is no need to call pytest --durations=XX to get this separate report + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66 + dlist = [] + for replist in tr.stats.values(): + for rep in replist: + if hasattr(rep, "duration"): + dlist.append(rep) + if dlist: + dlist.sort(key=lambda x: x.duration, reverse=True) + with open(report_files["durations"], "w") as f: + durations_min = 0.05 # sec + f.write("slowest durations\n") + for i, rep in enumerate(dlist): + if rep.duration < durations_min: + f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted") + break + f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") + + def summary_failures_short(tr): + # expecting that the reports were --tb=long (default) so we chop them off here to the last frame + reports = tr.getreports("failed") + if not reports: + return + tr.write_sep("=", "FAILURES SHORT STACK") + for rep in reports: + msg = tr._getfailureheadline(rep) + tr.write_sep("_", msg, red=True, bold=True) + # chop off the optional leading extra frames, leaving only the last one + longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S) + tr._tw.line(longrepr) + # note: not printing out any rep.sections to keep the report short + + # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814 + # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g. 
+ # pytest-instafail does that) + + # report failures with line/short/long styles + config.option.tbstyle = "auto" # full tb + with open(report_files["failures_long"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + # config.option.tbstyle = "short" # short tb + with open(report_files["failures_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + summary_failures_short(tr) + + config.option.tbstyle = "line" # one line per error + with open(report_files["failures_line"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + with open(report_files["errors"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_errors() + + with open(report_files["warnings"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_warnings() # normal warnings + tr.summary_warnings() # final warnings + + tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary()) + with open(report_files["passes"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_passes() + + with open(report_files["summary_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.short_test_summary() + + with open(report_files["stats"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_stats() + + # restore: + tr._tw = orig_writer + tr.reportchars = orig_reportchars + config.option.tbstyle = orig_tbstyle diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index c00a28e105..b63dbd2b28 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -47,6 +47,9 @@ default_cache_path = os.path.join(hf_cache_home, "diffusers") CONFIG_NAME = "config.json" +WEIGHTS_NAME = "diffusion_pytorch_model.bin" +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" +ONNX_WEIGHTS_NAME = "model.onnx" HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" DIFFUSERS_CACHE = default_cache_path DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" diff --git a/src/diffusers/utils/dummy_flax_and_transformers_objects.py b/src/diffusers/utils/dummy_flax_and_transformers_objects.py new file mode 100644 index 0000000000..51ee3b1848 --- /dev/null +++ b/src/diffusers/utils/dummy_flax_and_transformers_objects.py @@ -0,0 +1,11 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. 
+# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class FlaxStableDiffusionPipeline(metaclass=DummyObject): + _backends = ["flax", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax", "transformers"]) diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index 9615afb6f9..1e3ac002a6 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -11,6 +11,27 @@ class FlaxModelMixin(metaclass=DummyObject): requires_backends(self, ["flax"]) +class FlaxUNet2DConditionModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxAutoencoderKL(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxDiffusionPipeline(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxDDIMScheduler(metaclass=DummyObject): _backends = ["flax"] @@ -46,13 +67,6 @@ class FlaxPNDMScheduler(metaclass=DummyObject): requires_backends(self, ["flax"]) -class FlaxUNet2DConditionModel(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - class FlaxScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["flax"] diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py index 7771a5a5bf..8c1c77d10b 100644 --- a/src/diffusers/utils/logging.py +++ b/src/diffusers/utils/logging.py @@ -266,7 +266,7 @@ def reset_format() -> None: def warning_advice(self, *args, **kwargs): """ - This method is identical to `logger.warninging()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this + This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this warning will not be printed """ no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index b02f62d02d..45d483ce7b 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -59,10 +59,17 @@ class BaseOutput(OrderedDict): if not len(class_fields): raise ValueError(f"{self.__class__.__name__} has no fields.") - for field in class_fields: - v = getattr(self, field.name) - if v is not None: - self[field.name] = v + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and isinstance(first_field, dict): + for key, value in first_field.items(): + self[key] = value + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v def __delitem__(self, *args, **kwargs): raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..e116f40e64 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,44 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# tests directory-specific settings - this file is run automatically +# by pytest before any tests are run + +import sys +import warnings +from os.path import abspath, dirname, join + + +# allow having multiple repository checkouts and not needing to remember to rerun +# 'pip install -e .[dev]' when switching between checkouts and running tests. +git_repo_path = abspath(join(dirname(dirname(__file__)), "src")) +sys.path.insert(1, git_repo_path) + +# silence FutureWarning warnings in tests since often we can't act on them until +# they become normal warnings - i.e. the tests still need to test the current functionality +warnings.simplefilter(action="ignore", category=FutureWarning) + + +def pytest_addoption(parser): + from diffusers.testing_utils import pytest_addoption_shared + + pytest_addoption_shared(parser) + + +def pytest_terminal_summary(terminalreporter): + from diffusers.testing_utils import pytest_terminal_summary_main + + make_reports = terminalreporter.config.getoption("--make-reports") + if make_reports: + pytest_terminal_summary_main(terminalreporter, id=make_reports) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1e98fc9de7..b0d00b863a 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -246,3 +246,21 @@ class ModelTesterMixin: outputs_tuple = model(**inputs_dict, return_dict=False) recursive_check(outputs_tuple, outputs_dict) + + def test_enable_disable_gradient_checkpointing(self): + if not self.model_class._supports_gradient_checkpointing: + return # Skip test if model does not support gradient checkpointing + + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + # at init model should have gradient checkpointing disabled + model = self.model_class(**init_dict) + self.assertFalse(model.is_gradient_checkpointing) + + # check enable works + model.enable_gradient_checkpointing() + self.assertTrue(model.is_gradient_checkpointing) + + # check disable works + model.disable_gradient_checkpointing() + self.assertFalse(model.is_gradient_checkpointing) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index b16a4e1c44..80055c1a10 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -18,8 +18,8 @@ import unittest import torch -from diffusers import UNet2DModel -from diffusers.testing_utils import floats_tensor, torch_device +from diffusers import UNet2DConditionModel, UNet2DModel +from diffusers.testing_utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin @@ -159,6 +159,82 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) +class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet2DConditionModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 
32)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64), + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + out = model(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model.zero_grad() + out.sum().backward() + + # now we save the output and parameter gradients that we will use for comparison purposes with + # the non-checkpointed run. + output_not_checkpointed = out.data.clone() + grad_not_checkpointed = {} + for name, param in model.named_parameters(): + grad_not_checkpointed[name] = param.grad.data.clone() + + model.enable_gradient_checkpointing() + out = model(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model.zero_grad() + out.sum().backward() + + # now we save the output and parameter gradients that we will use for comparison purposes with + # the non-checkpointed run. 
+ output_checkpointed = out.data.clone() + grad_checkpointed = {} + for name, param in model.named_parameters(): + grad_checkpointed[name] = param.grad.data.clone() + + # compare the output and parameters gradients + self.assertTrue((output_checkpointed == output_not_checkpointed).all()) + for name in grad_checkpointed: + self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5)) + + # TODO(Patrick) - Re-add this test after having cleaned up LDM # def test_output_pretrained_spatial_transformer(self): # model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial") @@ -231,6 +307,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): inputs_dict = self.dummy_input return init_dict, inputs_dict + @slow def test_from_pretrained_hub(self): model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True) self.assertIsNotNone(model) @@ -244,6 +321,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): assert image is not None, "Make sure output is not None" + @slow def test_output_pretrained_ve_mid(self): model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256") model.to(torch_device) diff --git a/tests/test_outputs.py b/tests/test_outputs.py new file mode 100644 index 0000000000..3c3054c885 --- /dev/null +++ b/tests/test_outputs.py @@ -0,0 +1,60 @@ +import unittest +from dataclasses import dataclass +from typing import List, Union + +import numpy as np + +import PIL.Image +from diffusers.utils.outputs import BaseOutput + + +@dataclass +class CustomOutput(BaseOutput): + images: Union[List[PIL.Image.Image], np.ndarray] + + +class ConfigTester(unittest.TestCase): + def test_outputs_single_attribute(self): + outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4)) + + # check every way of getting the attribute + assert isinstance(outputs.images, np.ndarray) + assert outputs.images.shape == (1, 3, 4, 4) + assert isinstance(outputs["images"], np.ndarray) + assert outputs["images"].shape == (1, 3, 4, 4) + assert isinstance(outputs[0], np.ndarray) + assert outputs[0].shape == (1, 3, 4, 4) + + # test with a non-tensor attribute + outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))]) + + # check every way of getting the attribute + assert isinstance(outputs.images, list) + assert isinstance(outputs.images[0], PIL.Image.Image) + assert isinstance(outputs["images"], list) + assert isinstance(outputs["images"][0], PIL.Image.Image) + assert isinstance(outputs[0], list) + assert isinstance(outputs[0][0], PIL.Image.Image) + + def test_outputs_dict_init(self): + # test output reinitialization with a `dict` for compatibility with `accelerate` + outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)}) + + # check every way of getting the attribute + assert isinstance(outputs.images, np.ndarray) + assert outputs.images.shape == (1, 3, 4, 4) + assert isinstance(outputs["images"], np.ndarray) + assert outputs["images"].shape == (1, 3, 4, 4) + assert isinstance(outputs[0], np.ndarray) + assert outputs[0].shape == (1, 3, 4, 4) + + # test with a non-tensor attribute + outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]}) + + # check every way of getting the attribute + assert isinstance(outputs.images, list) + assert isinstance(outputs.images[0], PIL.Image.Image) + assert isinstance(outputs["images"], list) + assert isinstance(outputs["images"][0], PIL.Image.Image) + assert isinstance(outputs[0], list) + assert isinstance(outputs[0][0], PIL.Image.Image) diff --git 
a/tests/test_pipelines.py b/tests/test_pipelines.py index 102a55a93e..dddf42bd03 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -46,11 +46,10 @@ from diffusers import ( UNet2DModel, VQModel, ) -from diffusers.modeling_utils import WEIGHTS_NAME from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device -from diffusers.utils import CONFIG_NAME +from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -1105,7 +1104,7 @@ class PipelineTesterMixin(unittest.TestCase): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837]) + expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow @@ -1325,14 +1324,58 @@ class PipelineTesterMixin(unittest.TestCase): assert image.shape == (512, 512, 3) assert np.abs(expected_image - image).max() < 1e-2 + @slow + @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") + def test_stable_diffusion_inpaint_pipeline_k_lms(self): + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/red_cat_sitting_on_a_park_bench_k_lms.png" + ) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 + + lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") + + model_id = "CompVis/stable-diffusion-v1-4" + pipe = StableDiffusionInpaintPipeline.from_pretrained( + model_id, + scheduler=lms, + safety_checker=self.dummy_safety_checker, + use_auth_token=True, + ) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + prompt = "A red cat sitting on a park bench" + + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 + @slow def test_stable_diffusion_onnx(self): - from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models - - with tempfile.TemporaryDirectory() as tmpdirname: - convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14) - - sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider") + sd_pipe = StableDiffusionOnnxPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True + ) prompt = "A painting of a squirrel eating a burger" np.random.seed(0) diff --git a/utils/check_tf_ops.py b/utils/check_tf_ops.py deleted file mode 100644 index a3b9593bb2..0000000000 --- a/utils/check_tf_ops.py +++ /dev/null @@ -1,101 +0,0 
@@ -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import json -import os - -from tensorflow.core.protobuf.saved_model_pb2 import SavedModel - - -# All paths are set with the intent you should run this script from the root of the repo with the command -# python utils/check_copies.py -REPO_PATH = "." - -# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model) -INTERNAL_OPS = [ - "Assert", - "AssignVariableOp", - "EmptyTensorList", - "MergeV2Checkpoints", - "ReadVariableOp", - "ResourceGather", - "RestoreV2", - "SaveV2", - "ShardedFilename", - "StatefulPartitionedCall", - "StaticRegexFullMatch", - "VarHandleOp", -] - - -def onnx_compliance(saved_model_path, strict, opset): - saved_model = SavedModel() - onnx_ops = [] - - with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f: - onnx_opsets = json.load(f)["opsets"] - - for i in range(1, opset + 1): - onnx_ops.extend(onnx_opsets[str(i)]) - - with open(saved_model_path, "rb") as f: - saved_model.ParseFromString(f.read()) - - model_op_names = set() - - # Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs) - for meta_graph in saved_model.meta_graphs: - # Add operations in the graph definition - model_op_names.update(node.op for node in meta_graph.graph_def.node) - - # Go through the functions in the graph definition - for func in meta_graph.graph_def.library.function: - # Add operations in each function - model_op_names.update(node.op for node in func.node_def) - - # Convert to list, sorted if you want - model_op_names = sorted(model_op_names) - incompatible_ops = [] - - for op in model_op_names: - if op not in onnx_ops and op not in INTERNAL_OPS: - incompatible_ops.append(op) - - if strict and len(incompatible_ops) > 0: - raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops) - elif len(incompatible_ops) > 0: - print(f"Found the following incompatible ops for the opset {opset}:") - print(*incompatible_ops, sep="\n") - else: - print(f"The saved model {saved_model_path} can properly be converted with ONNX.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).") - parser.add_argument( - "--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested." - ) - parser.add_argument( - "--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model." 
- ) - parser.add_argument( - "--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)" - ) - args = parser.parse_args() - - if args.framework == "onnx": - onnx_compliance(args.saved_model_path, args.strict, args.opset) diff --git a/utils/get_modified_files.py b/utils/get_modified_files.py new file mode 100644 index 0000000000..44c60e96ab --- /dev/null +++ b/utils/get_modified_files.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# this script reports modified .py files under the desired list of top-level sub-dirs passed as a list of arguments, e.g.: +# python ./utils/get_modified_files.py utils src tests examples +# +# it uses git to find the forking point and which files were modified - i.e. files not under git won't be considered +# since the output of this script is fed into Makefile commands it doesn't print a newline after the results + +import re +import subprocess +import sys + + +fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8") +modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split() + +joined_dirs = "|".join(sys.argv[1:]) +regex = re.compile(rf"^({joined_dirs}).*?\.py$") + +relevant_modified_files = [x for x in modified_files if regex.match(x)] +print(" ".join(relevant_modified_files), end="")
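
The filtering step in `utils/get_modified_files.py` keeps only `.py` files that live under the top-level directories passed on the command line. A minimal sketch of that filter with made-up file paths (the list below is illustrative, not the output of a real `git diff`):

```python
import re

# Hypothetical modified files, standing in for the `git diff --name-only` output.
modified_files = [
    "src/diffusers/utils/outputs.py",
    "tests/test_outputs.py",
    "docs/source/api/models.mdx",
]

joined_dirs = "|".join(["utils", "src", "tests", "examples"])
regex = re.compile(rf"^({joined_dirs}).*?\.py$")

# Only .py files under the requested top-level directories survive the filter.
print(" ".join(x for x in modified_files if regex.match(x)), end="")
# -> src/diffusers/utils/outputs.py tests/test_outputs.py
```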
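
As a side note on the scheduler hunks earlier in this diff: several `add_noise` implementations replace `self.match_shape(...)` with an explicit flatten-and-broadcast loop so the per-timestep coefficients line up with samples of arbitrary rank. A minimal sketch of what that reshaping does, using made-up shapes (a batch of four image-shaped samples) and only `jax.numpy`:

```python
import jax.numpy as jnp

original_samples = jnp.zeros((4, 3, 32, 32))       # (batch, channels, height, width)
sqrt_alpha_prod = jnp.array([0.9, 0.8, 0.7, 0.6])  # one coefficient per sampled timestep

# Append trailing singleton axes until the coefficients broadcast against the samples.
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
    sqrt_alpha_prod = sqrt_alpha_prod[..., None]

print(sqrt_alpha_prod.shape)  # (4, 1, 1, 1): multiplies cleanly against (4, 3, 32, 32)
```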