
Merge branch 'main' of https://github.com/huggingface/diffusers into dreambooth-example

This commit is contained in:
patil-suraj
2022-09-26 15:08:12 +02:00
67 changed files with 3682 additions and 394 deletions

View File

@@ -41,4 +41,15 @@ jobs:
- name: Run all non-slow selected tests on CPU
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s tests/
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_cpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: pr_torch_test_reports
path: reports

View File

@@ -49,4 +49,66 @@ jobs:
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s tests/
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_gpu tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_gpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: torch_test_reports
path: reports
run_examples_single_gpu:
name: Examples tests
strategy:
fail-fast: false
matrix:
machine_type: [ single-gpu ]
runs-on: [ self-hosted, docker-gpu, '${{ matrix.machine_type }}' ]
container:
image: nvcr.io/nvidia/pytorch:22.07-py3
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: |
nvidia-smi
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip uninstall -y torch torchvision torchtext
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
python -m pip install -e .[quality,test,training]
- name: Environment
run: |
python utils/print_env.py
- name: Run example tests on GPU
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_gpu examples/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/examples_torch_gpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: examples_test_reports
path: reports

View File

@@ -4,8 +4,9 @@
[default.extend-identifiers]
[default.extend-words]
NIN_="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
NIN="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
nd="np" # nd may be np (numpy)
parms="parms" # parms is used in scripts/convert_original_stable_diffusion_to_diffusers.py
[files]

View File

@@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## AutoencoderKL
[[autodoc]] AutoencoderKL
## FlaxModelMixin
[[autodoc]] FlaxModelMixin
## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
## FlaxUNet2DConditionModel
[[autodoc]] FlaxUNet2DConditionModel
## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput
## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL

View File

@@ -53,7 +53,7 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochatic_karras_ve](./stochatic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
@@ -85,7 +85,7 @@ not be used for training. If you want to store the gradients during the forward
We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire
all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file iteslf, should be inherited from (and only from) the [`DiffusionPipeline` class](.../diffusion_pipeline) or be directly attached to the model and scheduler components of the pipeline.
- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](.../diffusion_pipeline) or be directly attached to the model and scheduler components of the pipeline.
- **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and
use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most
logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method.

View File

@@ -44,6 +44,6 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochatic_karras_ve](./api/pipelines/stochatic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.

View File

@@ -27,6 +27,6 @@ pip install diffusers
### Schedulers
### Pipeliens
### Pipelines

View File

@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# Unonditional Image Generation
# Unconditional Image Generation
The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference

View File

@@ -2,5 +2,6 @@
**Community** examples consist of both inference and training examples that have been added by the community.
| Example | Description | Author | |
|:----------|:-------------|:-------------|------:|
| Example | Description | Author | Colab |
|:----------|:----------------------|:-----------------|----------:|
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [Suraj Patil](https://github.com/patil-suraj/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) |

View File

@@ -0,0 +1,321 @@
import inspect
from typing import List, Optional, Union
import torch
from torch import nn
from torch.nn import functional as F
from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
class MakeCutouts(nn.Module):
def __init__(self, cut_size, cut_power=1.0):
super().__init__()
self.cut_size = cut_size
self.cut_power = cut_power
def forward(self, pixel_values, num_cutouts):
sideY, sideX = pixel_values.shape[2:4]
max_size = min(sideX, sideY)
min_size = min(sideX, sideY, self.cut_size)
cutouts = []
for _ in range(num_cutouts):
size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
offsetx = torch.randint(0, sideX - size + 1, ())
offsety = torch.randint(0, sideY - size + 1, ())
cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
return torch.cat(cutouts)
def spherical_dist_loss(x, y):
x = F.normalize(x, dim=-1)
y = F.normalize(y, dim=-1)
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
def set_requires_grad(model, value):
for param in model.parameters():
param.requires_grad = value
class CLIPGuidedStableDiffusion(DiffusionPipeline):
"""CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
- https://github.com/Jack000/glid-3-xl
- https://github.dev/crowsonkb/k-diffusion
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
clip_model: CLIPModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
feature_extractor: CLIPFeatureExtractor,
):
super().__init__()
scheduler = scheduler.set_format("pt")
self.register_modules(
vae=vae,
text_encoder=text_encoder,
clip_model=clip_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
self.make_cutouts = MakeCutouts(feature_extractor.size)
set_requires_grad(self.text_encoder, False)
set_requires_grad(self.clip_model, False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
self.enable_attention_slicing(None)
def freeze_vae(self):
set_requires_grad(self.vae, False)
def unfreeze_vae(self):
set_requires_grad(self.vae, True)
def freeze_unet(self):
set_requires_grad(self.unet, False)
def unfreeze_unet(self):
set_requires_grad(self.unet, True)
@torch.enable_grad()
def cond_fn(
self,
latents,
timestep,
index,
text_embeddings,
noise_pred_original,
text_embeddings_clip,
clip_guidance_scale,
num_cutouts,
use_cutouts=True,
):
latents = latents.detach().requires_grad_()
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[index]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
else:
latent_model_input = latents
# predict the noise residual
noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
if isinstance(self.scheduler, PNDMScheduler):
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
beta_prod_t = 1 - alpha_prod_t
# compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
fac = torch.sqrt(beta_prod_t)
sample = pred_original_sample * (fac) + latents * (1 - fac)
elif isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[index]
sample = latents - sigma * noise_pred
else:
raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
sample = 1 / 0.18215 * sample
image = self.vae.decode(sample).sample
image = (image / 2 + 0.5).clamp(0, 1)
if use_cutouts:
image = self.make_cutouts(image, num_cutouts)
else:
image = transforms.Resize(self.feature_extractor.size)(image)
image = self.normalize(image)
image_embeddings_clip = self.clip_model.get_image_features(image).float()
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
if use_cutouts:
dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
dists = dists.view([num_cutouts, sample.shape[0], -1])
loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
else:
loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale
grads = -torch.autograd.grad(loss, latents)[0]
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents.detach() + grads * (sigma**2)
noise_pred = noise_pred_original
else:
noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
return noise_pred, latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
clip_guidance_scale: Optional[float] = 100,
clip_prompt: Optional[Union[str, List[str]]] = None,
num_cutouts: Optional[int] = 4,
use_cutouts: Optional[bool] = True,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
if clip_guidance_scale > 0:
if clip_prompt is not None:
clip_text_input = self.tokenizer(
clip_prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids.to(self.device)
else:
clip_text_input = text_input.input_ids.to(self.device)
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
max_length = text_input.input_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# get the initial random noise unless the user supplied it
# Unlike in other pipelines, latents need to be generated on the target device
# for 1-to-1 reproducibility of results with the CompVis implementation.
# However, this currently doesn't work on `mps`.
latents_device = "cpu" if self.device.type == "mps" else self.device
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None:
latents = torch.randn(
latents_shape,
generator=generator,
device=latents_device,
)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
if accepts_offset:
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[i]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform classifier free guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform clip guidance
if clip_guidance_scale > 0:
text_embeddings_for_guidance = (
text_embeddings.chunk(2)[0] if do_classifier_free_guidance else text_embeddings
)
noise_pred, latents = self.cond_fn(
latents,
t,
i,
text_embeddings_for_guidance,
noise_pred,
text_embeddings_clip,
clip_guidance_scale,
num_cutouts,
use_cutouts,
)
# compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, i, latents).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, None)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
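
For readers who want to try the community pipeline above, a minimal usage sketch follows. It is not part of this commit: the model ids, the CLIP checkpoint, and the scheduler hyperparameters are illustrative assumptions.

```python
# Hypothetical usage sketch for CLIPGuidedStableDiffusion; model ids and
# scheduler hyperparameters are illustrative assumptions.
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer

sd_id = "CompVis/stable-diffusion-v1-4"
clip_id = "openai/clip-vit-large-patch14"

guided_pipeline = CLIPGuidedStableDiffusion(
    vae=AutoencoderKL.from_pretrained(sd_id, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(sd_id, subfolder="text_encoder"),
    clip_model=CLIPModel.from_pretrained(clip_id),
    tokenizer=CLIPTokenizer.from_pretrained(sd_id, subfolder="tokenizer"),
    unet=UNet2DConditionModel.from_pretrained(sd_id, subfolder="unet"),
    scheduler=LMSDiscreteScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
    ),
    feature_extractor=CLIPFeatureExtractor.from_pretrained(clip_id),
).to("cuda")

image = guided_pipeline(
    "a fantasy landscape, trending on artstation",
    clip_guidance_scale=100,
    num_cutouts=4,
).images[0]
image.save("clip_guided.png")
```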

examples/conftest.py Normal file
View File

@@ -0,0 +1,45 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# tests directory-specific settings - this file is run automatically
# by pytest before any tests are run
import sys
import warnings
from os.path import abspath, dirname, join
# allow having multiple repository checkouts and not needing to remember to rerun
# 'pip install -e .[dev]' when switching between checkouts and running tests.
git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
sys.path.insert(1, git_repo_path)
# silence FutureWarning warnings in tests since often we can't act on them until
# they become normal warnings - i.e. the tests still need to test the current functionality
warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_addoption(parser):
from diffusers.testing_utils import pytest_addoption_shared
pytest_addoption_shared(parser)
def pytest_terminal_summary(terminalreporter):
from diffusers.testing_utils import pytest_terminal_summary_main
make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
pytest_terminal_summary_main(terminalreporter, id=make_reports)

examples/test_examples.py Normal file
View File

@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
from typing import List
from accelerate.utils import write_basic_config
from diffusers.testing_utils import slow
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
# These utils relate to ensuring the right error message is received when running scripts
class SubprocessCallException(Exception):
pass
def run_command(command: List[str], return_stdout=False):
"""
Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
if an error occurred while running `command`
"""
try:
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
if return_stdout:
if hasattr(output, "decode"):
output = output.decode("utf-8")
return output
except subprocess.CalledProcessError as e:
raise SubprocessCallException(
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
) from e
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class ExamplesTestsAccelerate(unittest.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._tmpdir = tempfile.mkdtemp()
cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
write_basic_config(save_location=cls.configPath)
cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
@classmethod
def tearDownClass(cls):
super().tearDownClass()
shutil.rmtree(cls._tmpdir)
@slow
def test_train_unconditional(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/unconditional_image_generation/train_unconditional.py
--dataset_name huggan/few-shot-aurora
--resolution 64
--output_dir {tmpdir}
--train_batch_size 4
--num_epochs 1
--gradient_accumulation_steps 1
--learning_rate 1e-3
--lr_warmup_steps 5
--mixed_precision fp16
""".split()
run_command(self._launch_args + test_args, return_stdout=True)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
# logging test
self.assertTrue(len(os.listdir(os.path.join(tmpdir, "logs", "train_unconditional"))) > 0)
@slow
def test_textual_inversion(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/textual_inversion/textual_inversion.py
--pretrained_model_name_or_path CompVis/stable-diffusion-v1-4
--use_auth_token
--train_data_dir docs/source/imgs
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token toy
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 2
--max_train_steps 10
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--mixed_precision fp16
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.bin")))
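
A quick, hypothetical illustration of what `run_command` buys over a bare `subprocess.check_output` call: a failing command surfaces its captured output in the exception message.

```python
import sys

# assumes run_command / SubprocessCallException from the file above are in scope
try:
    run_command([sys.executable, "-c", "print('boom'); raise SystemExit(1)"])
except SubprocessCallException as err:
    print(err)  # message embeds the command and its captured output ("boom")
```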

View File

@@ -13,11 +13,14 @@
# limitations under the License.
import argparse
import os
import shutil
from pathlib import Path
import torch
from torch.onnx import export
import onnx
from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version
@@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int):
)
# UNET
unet_path = output_path / "unet" / "model.onnx"
onnx_export(
pipeline.unet,
model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
output_path=output_path / "unet" / "model.onnx",
output_path=unet_path,
ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
output_names=["out_sample"], # has to be different from "sample" for correct tracing
dynamic_axes={
@@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int):
opset=opset,
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
)
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = os.path.dirname(unet_model_path)
unet = onnx.load(unet_model_path)
# clean up existing tensor files
shutil.rmtree(unet_dir)
os.mkdir(unet_dir)
# collate external tensor files into one
onnx.save_model(
unet,
unet_model_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
convert_attribute=False,
)
# VAE ENCODER
vae_encoder = pipeline.vae
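
The collation step above can be read standalone. A sketch, assuming the external tensor files were exported next to `model.onnx` (paths are illustrative):

```python
# Sketch: re-save an ONNX model so all external tensors live in a single
# weights.pb next to model.onnx (paths are illustrative assumptions).
import os
import shutil

import onnx

model_path = "sd_onnx/unet/model.onnx"
model = onnx.load(model_path)          # pulls the scattered tensor files into memory
model_dir = os.path.dirname(model_path)
shutil.rmtree(model_dir)               # drop the per-tensor files
os.mkdir(model_dir)
onnx.save_model(
    model,
    model_path,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.pb",             # stored relative to model.onnx
    convert_attribute=False,
)
```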

View File

@@ -77,7 +77,7 @@ from setuptools import find_packages, setup
# 1. all dependencies should be listed here with their version requirements if any
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [
"Pillow",
"Pillow<10.0", # keep the PIL.Image.Resampling deprecation away
"accelerate>=0.11.0",
"black==22.8",
"datasets",
@@ -90,8 +90,9 @@ _deps = [
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib>=0.1.65,<=0.3.6",
"modelcards==0.1.4",
"modelcards>=0.1.4",
"numpy",
"onnxruntime-gpu",
"pytest",
"pytest-timeout",
"pytest-xdist",
@@ -100,6 +101,7 @@ _deps = [
"requests",
"tensorboard",
"torch>=1.4",
"torchvision",
"transformers>=4.21.0",
]
@@ -171,10 +173,19 @@ extras = {}
extras = {}
extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
extras["docs"] = ["hf-doc-builder"]
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
"datasets",
"onnxruntime-gpu",
"pytest",
"pytest-timeout",
"pytest-xdist",
"scipy",
"torchvision",
"transformers"
)
extras["torch"] = deps_list("torch")
if os.name == "nt": # windows

View File

@@ -65,6 +65,8 @@ else:
if is_flax_available():
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipeline_flax_utils import FlaxDiffusionPipeline
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
@@ -75,3 +77,8 @@ if is_flax_available():
)
else:
from .utils.dummy_flax_objects import * # noqa F403
if is_flax_available() and is_transformers_available():
from .pipelines import FlaxStableDiffusionPipeline
else:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403

View File

@@ -154,15 +154,25 @@ class ConfigMixin:
"""
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
# Allow dtype to be specified on initialization
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")
# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict)
return_tuple = (model,)
# Flax schedulers have a state, so return it.
if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
state = model.create_state()
return_tuple += (state,)
if return_unused_kwargs:
return model, unused_kwargs
return return_tuple + (unused_kwargs,)
else:
return model
return return_tuple if len(return_tuple) > 1 else model
@classmethod
def get_config_dict(
@@ -272,7 +282,7 @@ class ConfigMixin:
# remove general kwargs if present in dict
if "kwargs" in expected_keys:
expected_keys.remove("kwargs")
# remove flax interal keys
# remove flax internal keys
if hasattr(cls, "_flax_internal_args"):
for arg in cls._flax_internal_args:
expected_keys.remove(arg)
@@ -446,6 +456,9 @@ def flax_register_to_config(cls):
# Make sure init_kwargs override default kwargs
new_kwargs = {**default_kwargs, **init_kwargs}
# dtype should be part of `init_kwargs`, but not `new_kwargs`
if "dtype" in new_kwargs:
new_kwargs.pop("dtype")
# Get positional arguments aligned with kwargs
for i, arg in enumerate(args):
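
In practice this means Flax classes now hand back their state alongside the object, while PyTorch classes are unaffected. A hedged sketch (the checkpoint id is illustrative):

```python
# Sketch of the new return convention for Flax schedulers, assuming the
# class exposes state as the hunk above describes; checkpoint id is
# illustrative.
from diffusers import FlaxDDIMScheduler

scheduler, scheduler_state = FlaxDDIMScheduler.from_config(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)
```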

View File

@@ -2,7 +2,7 @@
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow",
"Pillow": "Pillow<10.0",
"accelerate": "accelerate>=0.11.0",
"black": "black==22.8",
"datasets": "datasets",
@@ -15,8 +15,10 @@ deps = {
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards==0.1.4",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
"onnxruntime-gpu": "onnxruntime-gpu",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
@@ -25,5 +27,6 @@ deps = {
"requests": "requests",
"tensorboard": "tensorboard",
"torch": "torch>=1.4",
"torchvision": "torchvision",
"transformers": "transformers>=4.21.0",
}

View File

@@ -0,0 +1,117 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
import re
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
from .utils import logging
logger = logging.get_logger(__name__)
def rename_key(key):
regex = r"\w+[.]\d+"
pats = re.findall(regex, key)
for pat in pats:
key = key.replace(pat, "_".join(pat.split(".")))
return key
#####################
# PyTorch => Flax #
#####################
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
# conv norm or layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
if (
any("norm" in str_ for str_ in pt_tuple_key)
and (pt_tuple_key[-1] == "bias")
and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
):
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
return renamed_pt_tuple_key, pt_tensor
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
return renamed_pt_tuple_key, pt_tensor
# embedding
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
return renamed_pt_tuple_key, pt_tensor
# conv layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
return renamed_pt_tuple_key, pt_tensor
# linear layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight":
pt_tensor = pt_tensor.T
return renamed_pt_tuple_key, pt_tensor
# old PyTorch layer norm weight
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
if pt_tuple_key[-1] == "gamma":
return renamed_pt_tuple_key, pt_tensor
# old PyTorch layer norm bias
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
if pt_tuple_key[-1] == "beta":
return renamed_pt_tuple_key, pt_tensor
return pt_tuple_key, pt_tensor
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
# Step 1: Convert pytorch tensor to numpy
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
# Step 2: Since the model is stateless, get random Flax params
random_flax_params = flax_model.init_weights(PRNGKey(init_key))
random_flax_state_dict = flatten_dict(random_flax_params)
flax_state_dict = {}
# Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items():
renamed_pt_key = rename_key(pt_key)
pt_tuple_key = tuple(renamed_pt_key.split("."))
# Correctly rename weight parameters
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
if flax_key in random_flax_state_dict:
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
raise ValueError(
f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
)
# also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor)
return unflatten_dict(flax_state_dict)
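
`rename_key` only rewrites `<name>.<digit>` segments, so PyTorch `ModuleList` indices become Flax-style underscored names while ordinary suffixes are left alone:

```python
import re

def rename_key(key):  # as defined in the file above
    for pat in re.findall(r"\w+[.]\d+", key):
        key = key.replace(pat, "_".join(pat.split(".")))
    return key

# list indices are underscored, "conv1.weight" stays untouched
assert rename_key("down_blocks.0.resnets.1.conv1.weight") == "down_blocks_0.resnets_1.conv1.weight"
```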

View File

@@ -27,12 +27,18 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from .modeling_utils import WEIGHTS_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from .modeling_utils import load_state_dict
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
logging,
)
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
logger = logging.get_logger(__name__)
@@ -45,7 +51,7 @@ class FlaxModelMixin:
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_flax_internal_args = ["name", "parent"]
_flax_internal_args = ["name", "parent", "dtype"]
@classmethod
def _from_config(cls, config, **kwargs):
@@ -245,6 +251,8 @@ class FlaxModelMixin:
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
from_pt (`bool`, *optional*, defaults to `False`):
Load the model weights from a PyTorch checkpoint save file.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -272,6 +280,7 @@ class FlaxModelMixin:
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
from_pt = kwargs.pop("from_pt", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
@@ -294,32 +303,44 @@ class FlaxModelMixin:
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
# model args
dtype=dtype,
**kwargs,
)
# Load model
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
pretrained_path_with_subfolder = (
pretrained_model_name_or_path
if subfolder is None
else os.path.join(pretrained_model_name_or_path, subfolder)
)
if os.path.isdir(pretrained_path_with_subfolder):
if from_pt:
if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
)
model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
# Load from a Flax checkpoint
model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
# At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
# Check if pytorch weights exist instead
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights."
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
" using `from_pt=True`."
)
else:
raise EnvironmentError(
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}."
f"{pretrained_path_with_subfolder}."
)
else:
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=FLAX_WEIGHTS_NAME,
filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -369,25 +390,32 @@ class FlaxModelMixin:
f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
try:
with open(model_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
if from_pt:
# Step 1: Get the pytorch file
pytorch_model_file = load_state_dict(model_file)
# Step 2: Convert the weights
state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
else:
try:
with open(model_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.ndarray
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
with open(model_file, "rb") as state_f:
state = from_bytes(cls, state_f.read())
except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
try:
with open(model_file) as f:
if f.read().startswith("version"):
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please"
" install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
" folder you cloned."
)
else:
raise ValueError from e
except (UnicodeDecodeError, ValueError):
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
# make sure all arrays are stored as jnp.ndarray
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
# flatten dicts
@@ -408,7 +436,7 @@ class FlaxModelMixin:
)
cls._missing_keys = missing_keys
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys = []
for key in state.keys():
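
A hedged sketch of the new `from_pt` path, loading PyTorch `.bin` weights directly into a Flax model (the checkpoint id is illustrative):

```python
# Sketch of the from_pt loading path added above; checkpoint id is an
# illustrative assumption.
from diffusers import FlaxUNet2DConditionModel

model, params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", from_pt=True
)
```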

View File

@@ -15,6 +15,7 @@
# limitations under the License.
import os
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
@@ -24,10 +25,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
logger = logging.get_logger(__name__)
@@ -124,10 +122,42 @@ class ModelMixin(torch.nn.Module):
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_supports_gradient_checkpointing = False
def __init__(self):
super().__init__()
@property
def is_gradient_checkpointing(self) -> bool:
"""
Whether gradient checkpointing is activated for this model or not.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
def enable_gradient_checkpointing(self):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if not self._supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
def disable_gradient_checkpointing(self):
"""
Deactivates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if self._supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
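
Downstream, training scripts can now trade compute for memory. A sketch (the checkpoint id is illustrative; only classes that set `_supports_gradient_checkpointing = True` accept this, as `UNet2DConditionModel` does below):

```python
# Sketch of the new gradient checkpointing API; checkpoint id is
# illustrative (gated checkpoints may additionally need use_auth_token).
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)
unet.enable_gradient_checkpointing()
assert unet.is_gradient_checkpointing
unet.disable_gradient_checkpointing()
```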

View File

@@ -12,6 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel
from ..utils import is_flax_available, is_torch_available
if is_torch_available():
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel
if is_flax_available():
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL

View File

@@ -249,13 +249,15 @@ class CrossAttention(nn.Module):
return tensor
def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, dim = hidden_states.shape
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
key = self.to_k(context)
value = self.to_v(context)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)

View File

@@ -17,6 +17,22 @@ import jax.numpy as jnp
class FlaxAttentionBlock(nn.Module):
r"""
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
Parameters:
query_dim (:obj:`int`):
Input hidden states dimension
heads (:obj:`int`, *optional*, defaults to 8):
Number of heads
dim_head (:obj:`int`, *optional*, defaults to 64):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
query_dim: int
heads: int = 8
dim_head: int = 64
@@ -32,7 +48,7 @@ class FlaxAttentionBlock(nn.Module):
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
@@ -74,6 +90,23 @@ class FlaxAttentionBlock(nn.Module):
class FlaxBasicTransformerBlock(nn.Module):
r"""
A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
https://arxiv.org/abs/1706.03762
Parameters:
dim (:obj:`int`):
Inner hidden states dimension
n_heads (:obj:`int`):
Number of heads
d_head (:obj:`int`):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
n_heads: int
d_head: int
@@ -82,9 +115,9 @@ class FlaxBasicTransformerBlock(nn.Module):
def setup(self):
# self attention
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -93,12 +126,12 @@ class FlaxBasicTransformerBlock(nn.Module):
def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
# cross attention
residual = hidden_states
hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
hidden_states = hidden_states + residual
# feed forward
@@ -110,6 +143,25 @@ class FlaxBasicTransformerBlock(nn.Module):
class FlaxSpatialTransformer(nn.Module):
r"""
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
https://arxiv.org/pdf/1506.02025.pdf
Parameters:
in_channels (:obj:`int`):
Input number of channels
n_heads (:obj:`int`):
Number of heads
d_head (:obj:`int`):
Hidden states dimension inside each head
depth (:obj:`int`, *optional*, defaults to 1):
Number of transformer blocks
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
n_heads: int
d_head: int
@@ -144,7 +196,6 @@ class FlaxSpatialTransformer(nn.Module):
def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
# import ipdb; ipdb.set_trace()
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
@@ -163,18 +214,56 @@ class FlaxSpatialTransformer(nn.Module):
class FlaxGluFeedForward(nn.Module):
r"""
Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
https://arxiv.org/abs/2002.05202
Parameters:
dim (:obj:`int`):
Inner hidden states dimension
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
# The second linear layer needs to be called
# net_2 for now to match the index of the Sequential layer
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.net_0(hidden_states)
hidden_states = self.net_2(hidden_states)
return hidden_states
class FlaxGEGLU(nn.Module):
r"""
Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
https://arxiv.org/abs/2002.05202.
Parameters:
dim (:obj:`int`):
Input hidden states dimension
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
def setup(self):
inner_dim = self.dim * 4
self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dense2 = nn.Dense(self.dim, dtype=self.dtype)
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.dense1(hidden_states)
hidden_states = self.proj(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
hidden_states = hidden_linear * nn.gelu(hidden_gelu)
hidden_states = self.dense2(hidden_states)
return hidden_states
return hidden_linear * nn.gelu(hidden_gelu)
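
The refactored GEGLU computation in isolation: one projection to twice the inner width, then a split and a GELU gate. Shapes are illustrative, for `dim=320`:

```python
# GEGLU in isolation; shapes assume dim=320, so inner_dim = 1280 and
# self.proj outputs 2560 features.
import jax.numpy as jnp
from flax import linen as nn

hidden = jnp.ones((1, 77, 2560))             # output of self.proj
linear, gate = jnp.split(hidden, 2, axis=2)  # two (1, 77, 1280) halves
out = linear * nn.gelu(gate)
assert out.shape == (1, 77, 1280)
```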

View File

@@ -19,7 +19,7 @@ import jax.numpy as jnp
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
def get_sinusoidal_embeddings(timesteps, embedding_dim):
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -29,7 +29,7 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim):
embeddings. :return: an [N x dim] tensor of positional embeddings.
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = math.log(10000) / (half_dim - freq_shift)
emb = jnp.exp(jnp.arange(half_dim) * -emb)
emb = timesteps[:, None] * emb[None, :]
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
@@ -37,6 +37,15 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim):
class FlaxTimestepEmbedding(nn.Module):
r"""
Time step Embedding Module. Learns embeddings for input time steps.
Args:
time_embed_dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
@@ -49,8 +58,16 @@ class FlaxTimestepEmbedding(nn.Module):
class FlaxTimesteps(nn.Module):
r"""
Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
Args:
dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
"""
dim: int = 32
freq_shift: float = 1
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(timesteps, self.dim)
return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)
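
To see what `freq_shift` changes, here is the embedding standalone (a sketch mirroring the function above; `freq_shift=1` reproduces the previous behavior, while the Flax UNet below now defaults to `0`):

```python
# Standalone sketch of get_sinusoidal_embeddings with freq_shift.
import math

import jax.numpy as jnp

def sinusoidal(timesteps, embedding_dim, freq_shift=1):
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - freq_shift)  # freq_shift only moves this divisor
    emb = jnp.exp(jnp.arange(half_dim) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    return jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)

print(sinusoidal(jnp.array([0, 10]), 8).shape)  # (2, 8)
```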

View File

@@ -149,7 +149,6 @@ class FirUpsample2D(nn.Module):
stride = (factor, factor)
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
@@ -161,7 +160,7 @@ class FirUpsample2D(nn.Module):
# Transpose weights.
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
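
The slicing-to-`torch.flip` swap matters because PyTorch tensors reject negative-step slicing. A quick check that `torch.flip` matches the NumPy reversal the old line intended:

```python
# w[..., ::-1, ::-1] raises a ValueError on a torch.Tensor ("step must be
# greater than zero"), so the reversal is done with torch.flip instead.
import numpy as np
import torch

w = torch.arange(24.0).reshape(2, 3, 4)
reference = torch.from_numpy(w.numpy()[..., ::-1, ::-1].copy())
assert torch.equal(torch.flip(w, dims=[1, 2]), reference)
```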

View File

@@ -3,12 +3,21 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
from .unet_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
UpBlock2D,
get_down_block,
get_up_block,
)
@dataclass
@@ -54,6 +63,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
@@ -188,6 +199,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
@@ -234,7 +249,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from typing import Tuple, Union
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
@@ -19,7 +19,7 @@ from .unet_blocks_flax import (
)
@dataclass
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
"""
Args:
@@ -39,10 +39,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
sample_size (`int`, *optional*): The size of the input sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
sample_size (`int`, *optional*):
The size of the input sample.
in_channels (`int`, *optional*, defaults to 4):
The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4):
The number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
"FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
@@ -51,10 +64,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
attention_head_dim (`int`, *optional*, defaults to 8):
The dimension of the attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
dropout (`float`, *optional*, defaults to 0):
Dropout probability for down, up and bottleneck blocks.
"""
sample_size: int = 32
@@ -73,10 +90,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
cross_attention_dim: int = 1280
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
freq_shift: int = 0
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
# init input tensors
sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
@@ -100,7 +118,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
)
# time
self.time_proj = FlaxTimesteps(block_out_channels[0])
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
# down
@@ -214,10 +232,17 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
When returning a tuple, the first element is the sample tensor.
"""
# 1. time
if not isinstance(timesteps, jnp.ndarray):
timesteps = jnp.array([timesteps], dtype=jnp.int32)
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
timesteps = timesteps.astype(dtype=jnp.float32)
timesteps = jnp.expand_dims(timesteps, 0)
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)
# 2. pre-process
sample = jnp.transpose(sample, (0, 2, 3, 1))
sample = self.conv_in(sample)
# 3. down
@@ -251,6 +276,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
sample = self.conv_norm_out(sample)
sample = nn.silu(sample)
sample = self.conv_out(sample)
sample = jnp.transpose(sample, (0, 3, 1, 2))
if not return_dict:
return (sample,)
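The corrected NCHW `sample_shape` in `init_weights`, together with the new transposes around `conv_in`/`conv_out`, means the Flax UNet accepts and returns channels-first tensors like its PyTorch counterpart while computing in channels-last internally. A minimal usage sketch (treat the exact call signature as an assumption):

```py
import jax
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

unet = FlaxUNet2DConditionModel(sample_size=32, in_channels=4)
params = unet.init_weights(jax.random.PRNGKey(0))

sample = jnp.zeros((1, 4, 32, 32))  # NCHW, matching the fixed init shape
timesteps = jnp.ones((1,), dtype=jnp.int32)
context = jnp.zeros((1, 77, unet.cross_attention_dim))
out = unet.apply({"params": params}, sample, timesteps, context)
print(out.sample.shape)  # (1, 4, 32, 32), NCHW again
```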

View File

@@ -527,6 +527,8 @@ class CrossAttnDownBlock2D(nn.Module):
else:
self.downsamplers = None
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
@@ -546,8 +548,22 @@ class CrossAttnDownBlock2D(nn.Module):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn), hidden_states, encoder_hidden_states
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
@@ -609,11 +625,24 @@ class DownBlock2D(nn.Module):
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None):
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.downsamplers is not None:
@@ -1072,6 +1101,8 @@ class CrossAttnUpBlock2D(nn.Module):
else:
self.upsamplers = None
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
@@ -1087,15 +1118,36 @@ class CrossAttnUpBlock2D(nn.Module):
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn), hidden_states, encoder_hidden_states
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -1150,6 +1202,8 @@ class UpBlock2D(nn.Module):
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
for resnet in self.resnets:
# pop res hidden states
@@ -1157,7 +1211,17 @@ class UpBlock2D(nn.Module):
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
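The same `create_custom_forward` wrapper recurs in every block above. A standalone sketch of the underlying pattern, with an illustrative module:

```py
import torch
import torch.utils.checkpoint


def create_custom_forward(module):
    # checkpoint() takes a callable; wrapping the module lets one helper
    # forward any number of positional inputs to it.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
# Activations inside `layer` are dropped in the forward pass and recomputed
# during backward, trading compute for memory.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
y.sum().backward()
```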

View File

@@ -19,6 +19,26 @@ from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
class FlaxCrossAttnDownBlock2D(nn.Module):
r"""
Cross Attention 2D Downsizing block - original architecture from Unet transformers:
https://arxiv.org/abs/2103.06104
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention block layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -55,7 +75,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
self.attentions = attentions
if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
output_states = ()
@@ -66,13 +86,30 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
output_states += (hidden_states,)
if self.add_downsample:
hidden_states = self.downsample(hidden_states)
hidden_states = self.downsamplers_0(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class FlaxDownBlock2D(nn.Module):
r"""
Flax 2D downsizing block
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet block layers
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -96,7 +133,7 @@ class FlaxDownBlock2D(nn.Module):
self.resnets = resnets
if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, temb, deterministic=True):
output_states = ()
@@ -106,13 +143,33 @@ class FlaxDownBlock2D(nn.Module):
output_states += (hidden_states,)
if self.add_downsample:
hidden_states = self.downsample(hidden_states)
hidden_states = self.downsamplers_0(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class FlaxCrossAttnUpBlock2D(nn.Module):
r"""
Cross Attention 2D Upsampling block - original architecture from Unet transformers:
https://arxiv.org/abs/2103.06104
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention block layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
prev_output_channel: int
@@ -151,7 +208,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
self.attentions = attentions
if self.add_upsample:
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
for resnet, attn in zip(self.resnets, self.attentions):
@@ -164,12 +221,31 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsample(hidden_states)
hidden_states = self.upsamplers_0(hidden_states)
return hidden_states
class FlaxUpBlock2D(nn.Module):
r"""
Flax 2D upsampling block
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
prev_output_channel (:obj:`int`):
Output channels from the previous block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet block layers
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsampling layer before each final output
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
prev_output_channel: int
@@ -196,7 +272,7 @@ class FlaxUpBlock2D(nn.Module):
self.resnets = resnets
if self.add_upsample:
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
for resnet in self.resnets:
@@ -208,12 +284,27 @@ class FlaxUpBlock2D(nn.Module):
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsample(hidden_states)
hidden_states = self.upsamplers_0(hidden_states)
return hidden_states
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
r"""
Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
Parameters:
in_channels (:obj:`int`):
Input channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of attention block layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
dropout: float = 0.0
num_layers: int = 1

View File

@@ -0,0 +1,812 @@
# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
import math
from functools import partial
from typing import Tuple
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..modeling_flax_utils import FlaxModelMixin
from ..utils import BaseOutput
@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
"""
Output of decoding method.
Args:
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Decoded output sample of the model. Output of the last layer of the model.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
sample: jnp.ndarray
@flax.struct.dataclass
class FlaxAutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`FlaxDiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
`FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "FlaxDiagonalGaussianDistribution"
class FlaxUpsample2D(nn.Module):
"""
Flax implementation of 2D Upsample layer
Args:
in_channels (`int`):
Input channels
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
def __call__(self, hidden_states):
batch, height, width, channels = hidden_states.shape
hidden_states = jax.image.resize(
hidden_states,
shape=(batch, height * 2, width * 2, channels),
method="nearest",
)
hidden_states = self.conv(hidden_states)
return hidden_states
class FlaxDownsample2D(nn.Module):
"""
Flax implementation of 2D Downsample layer
Args:
in_channels (`int`):
Input channels
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
strides=(2, 2),
padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states):
pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
hidden_states = jnp.pad(hidden_states, pad_width=pad)
hidden_states = self.conv(hidden_states)
return hidden_states
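`FlaxDownsample2D` mirrors the PyTorch convention of padding only the bottom/right edge before a stride-2 VALID convolution, which halves even spatial sizes exactly. The shape arithmetic, sketched:

```py
import jax.numpy as jnp

x = jnp.zeros((1, 64, 64, 3))                     # NHWC input
x = jnp.pad(x, ((0, 0), (0, 1), (0, 1), (0, 0)))  # 64 -> 65 on height and width
print(x.shape)                                    # (1, 65, 65, 3)
# 3x3 kernel, stride 2, VALID padding: out = (65 - 3) // 2 + 1 = 32, exactly half
print((65 - 3) // 2 + 1)
```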
class FlaxResnetBlock2D(nn.Module):
"""
Flax implementation of 2D Resnet Block.
Args:
in_channels (`int`):
Input channels
out_channels (`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
Whether to use `nin_shortcut`. This activates a new layer inside the ResNet block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int = None
dropout: float = 0.0
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32
def setup(self):
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.conv1 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.dropout_layer = nn.Dropout(self.dropout)
self.conv2 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
self.conv_shortcut = None
if use_nin_shortcut:
self.conv_shortcut = nn.Conv(
out_channels,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states, deterministic=True):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return hidden_states + residual
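The `nin_shortcut` (a 1x1 convolution on the residual branch) is only instantiated when the channel count changes, keeping the residual addition shape-compatible. A quick check, assuming the module is importable from `diffusers.models.vae_flax`:

```py
import jax
import jax.numpy as jnp
from diffusers.models.vae_flax import FlaxResnetBlock2D  # import path assumed

block = FlaxResnetBlock2D(in_channels=32, out_channels=64)
params = block.init(jax.random.PRNGKey(0), jnp.ones((1, 8, 8, 32)))["params"]
print("conv_shortcut" in params)  # True: the channel change activates the 1x1 shortcut
```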
class FlaxAttentionBlock(nn.Module):
r"""
Flax Convolutional based multi-head attention block for diffusion-based VAE.
Parameters:
channels (:obj:`int`):
Input channels
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
Number of attention heads
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
channels: int
num_head_channels: int = None
dtype: jnp.dtype = jnp.float32
def setup(self):
self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.query, self.key, self.value = dense(), dense(), dense()
self.proj_attn = dense()
def transpose_for_scores(self, projection):
new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D)
new_projection = projection.reshape(new_projection_shape)
# (B, T, H, D) -> (B, H, T, D)
new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
return new_projection
def __call__(self, hidden_states):
residual = hidden_states
batch, height, width, channels = hidden_states.shape
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.reshape((batch, height * width, channels))
query = self.query(hidden_states)
key = self.key(hidden_states)
value = self.value(hidden_states)
# transpose
query = self.transpose_for_scores(query)
key = self.transpose_for_scores(key)
value = self.transpose_for_scores(value)
# compute attentions
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
attn_weights = nn.softmax(attn_weights, axis=-1)
# attend to values
hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))
new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
hidden_states = hidden_states.reshape(new_hidden_states_shape)
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.reshape((batch, height, width, channels))
hidden_states = hidden_states + residual
return hidden_states
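Note the scale `1 / sqrt(sqrt(d))` applied to both query and key: multiplying each operand by `d**-0.25` is mathematically identical to dividing the logits by `sqrt(d)`, but keeps intermediate magnitudes smaller before the softmax. A quick numerical check:

```py
import math
import jax.numpy as jnp

d = 64
q = jnp.ones((1, 10, d))
k = jnp.ones((1, 10, d))
scale = 1 / math.sqrt(math.sqrt(d))
a = jnp.einsum("...qc,...kc->...qk", q * scale, k * scale)
b = jnp.einsum("...qc,...kc->...qk", q, k) / math.sqrt(d)
print(jnp.allclose(a, b))  # True
```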
class FlaxDownEncoderBlock2D(nn.Module):
r"""
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet layers
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
num_layers: int = 1
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self):
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
self.resnets = resnets
if self.add_downsample:
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, deterministic=deterministic)
if self.add_downsample:
hidden_states = self.downsamplers_0(hidden_states)
return hidden_states
class FlaxUpEncoderBlock2D(nn.Module):
r"""
Flax ResNet blocks-based decoder block (upsampling path) for diffusion-based VAE.
Parameters:
in_channels (:obj:`int`):
Input channels
out_channels (:obj:`int`):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet layers
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
num_layers: int = 1
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32
def setup(self):
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
res_block = FlaxResnetBlock2D(
in_channels=in_channels,
out_channels=self.out_channels,
dropout=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
self.resnets = resnets
if self.add_upsample:
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
def __call__(self, hidden_states, deterministic=True):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, deterministic=deterministic)
if self.add_upsample:
hidden_states = self.upsamplers_0(hidden_states)
return hidden_states
class FlaxUNetMidBlock2D(nn.Module):
r"""
Flax Unet Mid-Block module.
Parameters:
in_channels (:obj:`int`):
Input channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of ResNet layers
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
Number of attention heads for each attention block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout=self.dropout,
dtype=self.dtype,
)
]
attentions = []
for _ in range(self.num_layers):
attn_block = FlaxAttentionBlock(
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
)
attentions.append(attn_block)
res_block = FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout=self.dropout,
dtype=self.dtype,
)
resnets.append(res_block)
self.resnets = resnets
self.attentions = attentions
def __call__(self, hidden_states, deterministic=True):
hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, deterministic=deterministic)
return hidden_states
class FlaxEncoder(nn.Module):
r"""
Flax Implementation of VAE Encoder.
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
in_channels (:obj:`int`, *optional*, defaults to 3):
Input channels
out_channels (:obj:`int`, *optional*, defaults to 3):
Output channels
down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
DownEncoder block type
block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple containing the number of output channels for each block
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
Number of ResNet layers for each block
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
Number of groups for group normalization
act_fn (:obj:`str`, *optional*, defaults to `silu`):
Activation function
double_z (:obj:`bool`, *optional*, defaults to `False`):
Whether to double the last output channels
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
block_out_channels: Tuple[int] = (64,)
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
double_z: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
block_out_channels = self.block_out_channels
# in
self.conv_in = nn.Conv(
block_out_channels[0],
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
# downsampling
down_blocks = []
output_channel = block_out_channels[0]
for i, _ in enumerate(self.down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = FlaxDownEncoderBlock2D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
add_downsample=not is_final_block,
dtype=self.dtype,
)
down_blocks.append(down_block)
self.down_blocks = down_blocks
# middle
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
)
# end
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.conv_out = nn.Conv(
conv_out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
def __call__(self, sample, deterministic: bool = True):
# in
sample = self.conv_in(sample)
# downsampling
for block in self.down_blocks:
sample = block(sample, deterministic=deterministic)
# middle
sample = self.mid_block(sample, deterministic=deterministic)
# end
sample = self.conv_norm_out(sample)
sample = nn.swish(sample)
sample = self.conv_out(sample)
return sample
class FlaxDecoder(nn.Module):
r"""
Flax Implementation of VAE Decoder.
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
in_channels (:obj:`int`, *optional*, defaults to 3):
Input channels
out_channels (:obj:`int`, *optional*, defaults to 3):
Output channels
up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
UpDecoder block type
block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple containing the number of output channels for each block
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
Number of ResNet layers for each block
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
Number of groups for group normalization
act_fn (:obj:`str`, *optional*, defaults to `silu`):
Activation function
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
parameters `dtype`
"""
in_channels: int = 3
out_channels: int = 3
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int] = (64,)
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
dtype: jnp.dtype = jnp.float32
def setup(self):
block_out_channels = self.block_out_channels
# z to block_in
self.conv_in = nn.Conv(
block_out_channels[-1],
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
# middle
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
)
# upsampling
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
up_blocks = []
for i, _ in enumerate(self.up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = FlaxUpEncoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
add_upsample=not is_final_block,
dtype=self.dtype,
)
up_blocks.append(up_block)
prev_output_channel = output_channel
self.up_blocks = up_blocks
# end
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.conv_out = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)
def __call__(self, sample, deterministic: bool = True):
# z to block_in
sample = self.conv_in(sample)
# middle
sample = self.mid_block(sample, deterministic=deterministic)
# upsampling
for block in self.up_blocks:
sample = block(sample, deterministic=deterministic)
sample = self.conv_norm_out(sample)
sample = nn.swish(sample)
sample = self.conv_out(sample)
return sample
class FlaxDiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
# Last axis to account for channels-last
self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = jnp.exp(0.5 * self.logvar)
self.var = jnp.exp(self.logvar)
if self.deterministic:
self.var = self.std = jnp.zeros_like(self.mean)
def sample(self, key):
return self.mean + self.std * jax.random.normal(key, self.mean.shape)
def kl(self, other=None):
if self.deterministic:
return jnp.array([0.0])
if other is None:
return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3])
return 0.5 * jnp.sum(
jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
axis=[1, 2, 3],
)
def nll(self, sample, axis=[1, 2, 3]):
if self.deterministic:
return jnp.array([0.0])
logtwopi = jnp.log(2.0 * jnp.pi)
return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)
def mode(self):
return self.mean
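The `kl` method implements the closed form KL(N(mu, sigma^2) || N(0, I)) = 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2), summed over all non-batch axes; a standard-normal posterior therefore has zero KL, which makes a handy sanity check:

```py
import jax.numpy as jnp

mean = jnp.zeros((1, 2, 2, 4))
logvar = jnp.zeros((1, 2, 2, 4))  # sigma^2 = 1
var = jnp.exp(logvar)
kl = 0.5 * jnp.sum(mean**2 + var - 1.0 - logvar, axis=[1, 2, 3])
print(kl)  # [0.]
```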
@flax_register_to_config
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
r"""
Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational
Bayes by Diederik P. Kingma and Max Welling.
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matters related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
in_channels (:obj:`int`, *optional*, defaults to 3):
Input channels
out_channels (:obj:`int`, *optional*, defaults to 3):
Output channels
down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
DownEncoder block type
up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
UpDecoder block type
block_out_channels (:obj:`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple containing the number of output channels for each block
layers_per_block (:obj:`int`, *optional*, defaults to `1`):
Number of ResNet layers for each block
act_fn (:obj:`str`, *optional*, defaults to `silu`):
Activation function
latent_channels (:obj:`int`, *optional*, defaults to `4`):
Latent space channels
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
Number of groups for group normalization
sample_size (:obj:`int`, *optional*, defaults to `32`):
Sample input size
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
parameters `dtype`
"""
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int] = (64,)
layers_per_block: int = 1
act_fn: str = "silu"
latent_channels: int = 4
norm_num_groups: int = 32
sample_size: int = 32
dtype: jnp.dtype = jnp.float32
def setup(self):
self.encoder = FlaxEncoder(
in_channels=self.config.in_channels,
out_channels=self.config.latent_channels,
down_block_types=self.config.down_block_types,
block_out_channels=self.config.block_out_channels,
layers_per_block=self.config.layers_per_block,
act_fn=self.config.act_fn,
norm_num_groups=self.config.norm_num_groups,
double_z=True,
dtype=self.dtype,
)
self.decoder = FlaxDecoder(
in_channels=self.config.latent_channels,
out_channels=self.config.out_channels,
up_block_types=self.config.up_block_types,
block_out_channels=self.config.block_out_channels,
layers_per_block=self.config.layers_per_block,
norm_num_groups=self.config.norm_num_groups,
act_fn=self.config.act_fn,
dtype=self.dtype,
)
self.quant_conv = nn.Conv(
2 * self.config.latent_channels,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.post_quant_conv = nn.Conv(
self.config.latent_channels,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}
return self.init(rngs, sample)["params"]
def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
sample = jnp.transpose(sample, (0, 2, 3, 1))
hidden_states = self.encoder(sample, deterministic=deterministic)
moments = self.quant_conv(hidden_states)
posterior = FlaxDiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return FlaxAutoencoderKLOutput(latent_dist=posterior)
def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
if latents.shape[-1] != self.config.latent_channels:
latents = jnp.transpose(latents, (0, 2, 3, 1))
hidden_states = self.post_quant_conv(latents)
hidden_states = self.decoder(hidden_states, deterministic=deterministic)
hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))
if not return_dict:
return (hidden_states,)
return FlaxDecoderOutput(sample=hidden_states)
def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict)
if sample_posterior:
rng = self.make_rng("gaussian")
hidden_states = posterior.latent_dist.sample(rng)
else:
hidden_states = posterior.latent_dist.mode()
sample = self.decode(hidden_states, return_dict=return_dict).sample
if not return_dict:
return (sample,)
return FlaxDecoderOutput(sample=sample)
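Putting the pieces together, a minimal round trip through the autoencoder with the default config (treat the exact `apply(..., method=...)` invocation as an assumption about how this Linen module is driven):

```py
import jax
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL

vae = FlaxAutoencoderKL(sample_size=32)
params = vae.init_weights(jax.random.PRNGKey(0))

images = jnp.zeros((1, 3, 32, 32))  # NCHW; encode/decode transpose internally
latent_dist = vae.apply({"params": params}, images, method=vae.encode).latent_dist
latents = latent_dist.mode()
recon = vae.apply({"params": params}, latents, method=vae.decode).sample
print(recon.shape)  # (1, 3, 32, 32)
```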

View File

@@ -24,16 +24,13 @@ import numpy as np
from huggingface_hub import hf_hub_download
from .utils import is_onnx_available, logging
from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
if is_onnx_available():
import onnxruntime as ort
ONNX_WEIGHTS_NAME = "model.onnx"
logger = logging.get_logger(__name__)
@@ -49,7 +46,7 @@ class OnnxRuntimeModel:
return self.model.run(None, inputs)
@staticmethod
def load_model(path: Union[str, Path], provider=None):
def load_model(path: Union[str, Path], provider=None, sess_options=None):
"""
Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
@@ -63,7 +60,7 @@ class OnnxRuntimeModel:
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider"
return ort.InferenceSession(path, providers=[provider])
return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)
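The new `sess_options` pass-through lets callers tune the ONNX Runtime session before it is created. For example (repo id and thread count purely illustrative):

```py
import onnxruntime as ort
from diffusers import OnnxRuntimeModel

opts = ort.SessionOptions()
opts.intra_op_num_threads = 4  # illustrative tuning knob
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

model = OnnxRuntimeModel.from_pretrained(
    "some-user/some-onnx-model",  # hypothetical repo id
    provider="CPUExecutionProvider",
    sess_options=opts,
)
```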
def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
"""
@@ -117,6 +114,7 @@ class OnnxRuntimeModel:
cache_dir: Optional[str] = None,
file_name: Optional[str] = None,
provider: Optional[str] = None,
sess_options: Optional["ort.SessionOptions"] = None,
**kwargs,
):
"""
@@ -146,7 +144,9 @@ class OnnxRuntimeModel:
model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
# load model from local directory
if os.path.isdir(model_id):
model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
model = OnnxRuntimeModel.load_model(
os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
)
kwargs["model_save_dir"] = Path(model_id)
# load model from hub
else:
@@ -161,7 +161,7 @@ class OnnxRuntimeModel:
)
kwargs["model_save_dir"] = Path(model_cache_path).parent
kwargs["latest_model_name"] = Path(model_cache_path).name
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
return cls(model=model, **kwargs)
@classmethod

View File

@@ -0,0 +1,474 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
from typing import Dict, List, Optional, Union
import numpy as np
import flax
import PIL
from flax.core.frozen_dict import FrozenDict
from huggingface_hub import snapshot_download
from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
if is_transformers_available():
from transformers import FlaxPreTrainedModel
INDEX_FILE = "diffusion_flax_model.bin"
logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"FlaxModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_config", "from_config"],
"FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
"FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
},
}
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
class DummyChecker:
def __init__(self):
self.dummy = True
def import_flax_or_no_model(module, class_name):
try:
# 1. First make sure that if a Flax object is present, import this one
class_obj = getattr(module, "Flax" + class_name)
except AttributeError:
# 2. If this doesn't work, it's not a model and we don't append "Flax";
# a second lookup miss is a real error (a duplicate `except AttributeError` clause would never run)
class_obj = getattr(module, class_name, None)
if class_obj is None:
raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}")
return class_obj
@flax.struct.dataclass
class FlaxImagePipelineOutput(BaseOutput):
"""
Output class for image pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or a numpy array representing the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
class FlaxDiffusionPipeline(ConfigMixin):
r"""
Base class for all Flax diffusion pipelines.
[`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion
pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all
pipelines to:
- enabling/disabling the progress bar for the denoising iteration
Class attributes:
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
components of the diffusion pipeline.
"""
config_name = "model_index.json"
def register_modules(self, **kwargs):
# import it here to avoid circular import
from diffusers import pipelines
for name, module in kwargs.items():
# retrieve library
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# retrieve class_name
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}
# save model index config
self.register_to_config(**register_dict)
# set models
setattr(self, name, module)
def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]):
# TODO: handle inference_state
"""
Save all variables of the pipeline that can be saved and loaded as well as the pipeline's configuration file to
a directory. A pipeline variable can be saved and loaded if its class implements both a save and a loading
method. The pipeline can easily be re-loaded using the [`~FlaxDiffusionPipeline.from_pretrained`] class
method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
"""
self.save_config(save_directory)
model_index_dict = dict(self.config)
model_index_dict.pop("_class_name")
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module", None)
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
model_cls = sub_model.__class__
save_method_name = None
# search for the model's base class in LOADABLE_CLASSES
for library_name, library_classes in LOADABLE_CLASSES.items():
library = importlib.import_module(library_name)
for base_class, save_load_methods in library_classes.items():
class_candidate = getattr(library, base_class)
if issubclass(model_cls, class_candidate):
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
save_method_name = save_load_methods[0]
break
if save_method_name is not None:
break
# TODO(Patrick, Suraj): to delete after
if isinstance(sub_model, DummyChecker):
continue
save_method = getattr(sub_model, save_method_name)
expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
if expects_params:
save_method(
os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name]
)
else:
save_method(os.path.join(save_directory, pipeline_component_name))
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
The pipeline runs in inference mode by default (Dropout modules are deactivated, `deterministic=True`).
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
`CompVis/ldm-text2im-large-256`.
- A path to a *directory* containing pipeline weights saved using
[`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
dtype (`str` or `jnp.dtype`, *optional*):
Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines
`__init__` method. See example below for more information.
<Tip>
Passing `use_auth_token=True` is required when you want to use a private model, *e.g.*
`"CompVis/stable-diffusion-v1-4"`
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.
</Tip>
Examples:
```py
>>> from diffusers import FlaxDiffusionPipeline
>>> # Download pipeline from huggingface.co and cache.
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
>>> # Download pipeline that requires an authorization token
>>> # For more information on access tokens, please refer to this section
>>> # of [the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
>>> # Download pipeline, but overwrite scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
>>> pipeline = FlaxDiffusionPipeline.from_pretrained(
... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
... )
```
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_pt = kwargs.pop("from_pt", False)
dtype = kwargs.pop("dtype", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.get_config_dict(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
)
# make sure we only download sub-folders and `diffusers` filenames
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
)
else:
cached_folder = pretrained_model_name_or_path
config_dict = cls.get_config_dict(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if cls != FlaxDiffusionPipeline:
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {}
# inference_params
params = {}
# import it here to avoid circular import
from diffusers import pipelines
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
# TODO(Patrick, Suraj) - delete later
if class_name == "DummyChecker":
library_name = "stable_diffusion"
class_name = "FlaxStableDiffusionSafetyChecker"
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
expected_class_obj = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
# set passed class object
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
if from_pt:
class_obj = import_flax_or_no_model(pipeline_module, class_name)
else:
class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
if from_pt:
class_obj = import_flax_or_no_model(library, class_name)
else:
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
if loaded_sub_model is None:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
load_method = getattr(class_obj, load_method_name)
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loadable_folder = os.path.join(cached_folder, name)
else:
loadable_folder = cached_folder
if issubclass(class_obj, FlaxModelMixin):
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
params[name] = loaded_params
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
# make sure we don't initialize the weights to save time
if name == "safety_checker":
loaded_sub_model = DummyChecker()
loaded_params = {}
elif from_pt:
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
loaded_params = loaded_sub_model.params
del loaded_sub_model._params
else:
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
params[name] = loaded_params
elif issubclass(class_obj, SchedulerMixin):
loaded_sub_model, scheduler_state = load_method(loadable_folder)
params[name] = scheduler_state
else:
loaded_sub_model = load_method(loadable_folder)
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
model = pipeline_class(**init_kwargs, dtype=dtype)
return model, params
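Unlike the PyTorch pipeline, the Flax loader returns the pipeline object and its parameters separately, since Flax modules are stateless. Typical use (repo id reused from the docstring example above; whether a given repo actually ships Flax weights is repo-dependent):

```py
from diffusers import FlaxDiffusionPipeline

pipeline, params = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
# `params` is a plain dict keyed by component name ("unet", "scheduler", ...),
# ready to be replicated across devices, e.g. with flax.jax_utils.replicate(params).
```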
@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
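`numpy_to_pil` expects float arrays in [0, 1] with channels last, and promotes a single image to a batch of one:

```py
import numpy as np
from diffusers import FlaxDiffusionPipeline

images = np.random.rand(64, 64, 3).astype("float32")  # single HWC image in [0, 1]
pils = FlaxDiffusionPipeline.numpy_to_pil(images)  # list with one PIL.Image
print(len(pils), pils[0].size)  # 1 (64, 64)
```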
# TODO: make it compatible with jax.lax
def progress_bar(self, iterable):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
raise ValueError(
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
return tqdm(iterable, **self._progress_bar_config)
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs

View File

@@ -30,10 +30,8 @@ from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .modeling_utils import WEIGHTS_NAME
from .onnx_utils import ONNX_WEIGHTS_NAME
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -237,8 +235,8 @@ class DiffusionPipeline(ConfigMixin):
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overritten components are then directly passed to the pipelines `__init__`
method. See example below for more information.
specific pipeline class. The overwritten components are then directly passed to the pipelines
`__init__` method. See example below for more information.
<Tip>
@@ -284,6 +282,7 @@ class DiffusionPipeline(ConfigMixin):
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
@@ -341,6 +340,10 @@ class DiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
@@ -396,6 +399,7 @@ class DiffusionPipeline(ConfigMixin):
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
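The new `sess_options` keyword is forwarded into the ONNX model's loading kwargs alongside `provider`. A hedged usage sketch; the checkpoint id, revision, and thread setting are assumptions for illustration:

```python
import onnxruntime as ort
from diffusers import StableDiffusionOnnxPipeline

options = ort.SessionOptions()
options.intra_op_num_threads = 4  # assumed tuning knob: cap per-op CPU threads

pipe = StableDiffusionOnnxPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",  # hypothetical checkpoint id
    revision="onnx",
    provider="CPUExecutionProvider",
    sess_options=options,
)
```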

View File

@@ -73,7 +73,7 @@ not be used for training. If you want to store the gradients during the forward
We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire for
all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file iteslf, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline.
- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline.
- **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and
use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most
logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method.

View File

@@ -1,4 +1,4 @@
from ..utils import is_onnx_available, is_transformers_available
from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
@@ -7,7 +7,7 @@ from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
if is_transformers_available():
if is_torch_available() and is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
StableDiffusionImg2ImgPipeline,
@@ -17,3 +17,6 @@ if is_transformers_available():
if is_transformers_available() and is_onnx_available():
from .stable_diffusion import StableDiffusionOnnxPipeline
if is_transformers_available() and is_flax_available():
from .stable_diffusion import FlaxStableDiffusionPipeline
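A minimal sketch of the availability-guard pattern these `__init__` changes rely on, assuming the `diffusers.utils` helpers imported above:

```python
from diffusers.utils import is_flax_available, is_torch_available, is_transformers_available

if is_torch_available() and is_transformers_available():
    from diffusers import StableDiffusionPipeline      # torch-backed pipeline

if is_flax_available() and is_transformers_available():
    from diffusers import FlaxStableDiffusionPipeline  # JAX/Flax-backed pipeline
```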

View File

@@ -27,7 +27,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
vqvae ([`VQModel`]):
Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
bert ([`LDMBertModel`]):
Text-encoder model based on [BERT](ttps://huggingface.co/docs/transformers/model_doc/bert) architecture.
Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
tokenizer (`transformers.BertTokenizer`):
Tokenizer of class
[BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
@@ -397,7 +397,7 @@ class LDMBertAttention(nn.Module):
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
attn_output = self.out_proj(attn_output)
@@ -478,7 +478,7 @@ class LDMBertEncoderLayer(nn.Module):
class LDMBertPreTrainedModel(PreTrainedModel):
config_class = LDMBertConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
def _init_weights(self, module):

View File

@@ -67,7 +67,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt).sample[0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
@@ -89,7 +89,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt).sample[0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
@@ -115,7 +115,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt).sample[0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
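These README updates track the rename of the pipeline output field from `sample` to `images`. Under the same setup as the snippets above, both access styles below should work (a sketch, not verified against every pipeline):

```python
output = pipe(prompt)     # StableDiffusionPipelineOutput dataclass
image = output.images[0]

# or request a plain tuple: (images, nsfw_flags)
image = pipe(prompt, return_dict=False)[0][0]
```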

View File

@@ -6,7 +6,7 @@ import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, is_onnx_available, is_transformers_available
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available
@dataclass
@@ -35,3 +35,26 @@ if is_transformers_available():
if is_transformers_available() and is_onnx_available():
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
if is_transformers_available() and is_flax_available():
import flax
@flax.struct.dataclass
class FlaxStableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. The PIL images or NumPy array contain the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

View File

@@ -0,0 +1,217 @@
from typing import Dict, List, Optional, Union
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`FlaxAutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`FlaxCLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: FlaxAutoencoderKL,
text_encoder: FlaxCLIPTextModel,
tokenizer: CLIPTokenizer,
unet: FlaxUNet2DConditionModel,
scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
safety_checker: FlaxStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
dtype: jnp.dtype = jnp.float32,
):
super().__init__()
scheduler = scheduler.set_format("np")
self.dtype = dtype
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
text_input = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
return text_input.input_ids
def __call__(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: Optional[int] = 50,
height: Optional[int] = 512,
width: Optional[int] = 512,
guidance_scale: Optional[float] = 7.5,
latents: Optional[jnp.array] = None,
return_dict: bool = True,
debug: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.
Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# get prompt text embeddings
text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
# TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
# implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = (
batch_size,
self.unet.in_channels,
self.unet.sample_size,
self.unet.sample_size,
)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
def loop_body(step, args):
latents, scheduler_state = args
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input = jnp.concatenate([latents] * 2)
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])
# predict the noise residual
noise_pred = self.unet.apply(
{"params": params["unet"]},
jnp.array(latents_input),
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=context,
).sample
# perform guidance
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents, scheduler_state
scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
if debug:
# run with python for loop
for i in range(num_inference_steps):
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
else:
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
# TODO: check when flax vae gets merged into main
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
# image = jnp.asarray(image).transpose(0, 2, 3, 1)
# run safety checker
# TODO: check when flax safety checker gets merged into main
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
# image, has_nsfw_concept = self.safety_checker(
# images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
# )
has_nsfw_concept = False
if not return_dict:
return (image, has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
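A hedged end-to-end sketch of the new Flax pipeline on a single device (no `pmap`); the checkpoint id and the `from_pt=True` weight conversion are assumptions for illustration:

```python
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# from_pretrained returns the pipeline and its parameter tree separately
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", from_pt=True, dtype=jnp.float32
)

prompt_ids = pipeline.prepare_inputs("a photo of an astronaut riding a horse on mars")
rng = jax.random.PRNGKey(0)
images = pipeline(prompt_ids, params, rng, num_inference_steps=50).images
```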

View File

@@ -36,7 +36,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offsensive or harmful.
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -58,7 +58,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
@@ -278,8 +278,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
if output_type == "pil":
image = self.numpy_to_pil(image)

View File

@@ -48,7 +48,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offsensive or harmful.
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -70,7 +70,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
@@ -213,7 +213,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
# get prompt text embeddings
text_input = self.tokenizer(
@@ -265,8 +265,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
sigma = self.scheduler.sigmas[t_index]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
latent_model_input = latent_model_input.to(self.unet.dtype)
t = t.to(self.unet.dtype)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -284,14 +282,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents.to(self.vae.dtype)).sample
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
if output_type == "pil":
image = self.numpy_to_pil(image)
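The K-LMS input scaling kept above divides the latents by `sqrt(sigma**2 + 1)` so the model input matches the continuous ODE formulation. A standalone numeric sketch with made-up values:

```python
import torch

sigma = 14.6                          # noise level at the current step
latents = torch.randn(1, 4, 64, 64)
latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
```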

View File

@@ -12,7 +12,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -66,7 +66,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offsensive or harmful.
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -78,7 +78,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler],
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
):
@@ -89,7 +89,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
@@ -241,8 +241,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
if isinstance(self.scheduler, LMSDiscreteScheduler):
timesteps = torch.tensor(
[num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
)
else:
timesteps = self.scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
# add noise to latents using the timesteps
noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -287,8 +292,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latents = init_latents
t_start = max(num_inference_steps - init_timestep + offset, 0)
for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
t_index = t_start + i
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[t_index]
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -299,10 +309,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# scale and decode the image latents with vae
@@ -313,8 +328,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
image = image.cpu().permute(0, 2, 3, 1).numpy()
# run safety checker
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
if output_type == "pil":
image = self.numpy_to_pil(image)
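The inpainting changes above dispatch on the scheduler type: `LMSDiscreteScheduler.step` is indexed by position in the schedule (`t_index`) while the other schedulers take the timestep value `t`, and the masked re-noising of `init_latents_orig` uses the matching index. A condensed sketch of that branch; `scheduler`, `noise_pred`, the latents, and the mask are assumed from the surrounding loop:

```python
import torch
from diffusers import LMSDiscreteScheduler

if isinstance(scheduler, LMSDiscreteScheduler):
    latents = scheduler.step(noise_pred, t_index, latents).prev_sample
    init_latents_proper = scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
else:
    latents = scheduler.step(noise_pred, t, latents).prev_sample
    init_latents_proper = scheduler.add_noise(init_latents_orig, noise, t)
latents = init_latents_proper * mask + latents * (1 - mask)
```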

View File

@@ -48,20 +48,20 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
for concet_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concet_idx]
concept_threshold = self.special_care_embeds_weights[concet_idx].item()
result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concet_idx] > 0:
result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]})
for concept_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concept_idx]
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concept_idx] > 0:
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
adjustment = 0.01
for concet_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concet_idx]
concept_threshold = self.concept_embeds_weights[concet_idx].item()
result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concet_idx] > 0:
result_img["bad_concepts"].append(concet_idx)
for concept_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concept_idx]
concept_threshold = self.concept_embeds_weights[concept_idx].item()
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concept_idx] > 0:
result_img["bad_concepts"].append(concept_idx)
result.append(result_img)

View File

@@ -0,0 +1,147 @@
import warnings
from typing import Optional, Tuple
import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from transformers import CLIPConfig, FlaxPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
return jnp.matmul(norm_emb_1, norm_emb_2.T)
class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
config: CLIPConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim))
self.special_care_embeds = self.param(
"special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)
)
self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,))
def __call__(self, clip_input):
pooled_output = self.vision_model(clip_input)[1]
image_embeds = self.visual_projection(pooled_output)
special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
return special_cos_dist, cos_dist
def filtered_with_scores(self, special_cos_dist, cos_dist, images):
batch_size = special_cos_dist.shape[0]
special_cos_dist = np.asarray(special_cos_dist)
cos_dist = np.asarray(cos_dist)
result = []
for i in range(batch_size):
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
# increase this value to create a stronger `nsfw` filter
# at the cost of increasing the possibility of filtering benign image inputs
adjustment = 0.0
for concept_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concept_idx]
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concept_idx] > 0:
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
adjustment = 0.01
for concept_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concept_idx]
concept_threshold = self.concept_embeds_weights[concept_idx].item()
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concept_idx] > 0:
result_img["bad_concepts"].append(concept_idx)
result.append(result_img)
has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
images_was_copied = False
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
if has_nsfw_concept:
if not images_was_copied:
images_was_copied = True
images = images.copy()
images[idx] = np.zeros(images[idx].shape) # black image
if any(has_nsfw_concepts):
warnings.warn(
"Potential NSFW content was detected in one or more images. A black image will be returned"
" instead. Try again with a different prompt and/or seed."
)
return images, has_nsfw_concepts
class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
config_class = CLIPConfig
main_input_name = "clip_input"
module_class = FlaxStableDiffusionSafetyCheckerModule
def __init__(
self,
config: CLIPConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
if input_shape is None:
input_shape = (1, 224, 224, 3)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
clip_input = jax.random.normal(rng, input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
random_params = self.module.init(rngs, clip_input)["params"]
return random_params
def __call__(
self,
clip_input,
params: dict = None,
):
clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))
return self.module.apply(
{"params": params or self.params},
jnp.array(clip_input, dtype=jnp.float32),
rngs={},
)
def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
return module.filtered_with_scores(special_cos_dist, cos_dist, images)
return self.module.apply(
{"params": params or self.params},
special_cos_dist,
cos_dist,
images,
method=_filtered_with_scores,
)
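Despite its name, `jax_cosine_distance` above computes cosine *similarity* between row-normalized embeddings. A quick standalone check with toy vectors (the helper is copied from the file for self-containment):

```python
import jax.numpy as jnp

def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
    norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
    norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
    return jnp.matmul(norm_emb_1, norm_emb_2.T)

embs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
target = jnp.array([[1.0, 0.0]])
print(jax_cosine_distance(embs, target))  # [[1.], [0.]]: aligned vs. orthogonal
```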

View File

@@ -17,13 +17,33 @@
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class DDIMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -179,7 +199,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
use_clipped_model_output: bool = False,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[DDIMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -192,11 +212,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
@@ -222,6 +242,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
@@ -260,7 +281,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def add_noise(
self,
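The practical effect of this change: `DDIMScheduler.step` now returns a `DDIMSchedulerOutput` that carries `pred_original_sample` alongside `prev_sample`. A hedged sketch of using it inside a denoising loop (model and scheduler setup assumed):

```python
out = scheduler.step(noise_pred, t, latents)  # DDIMSchedulerOutput
latents = out.prev_sample                     # next model input
preview_x0 = out.pred_original_sample         # predicted x_0, e.g. for previews or guidance
```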

View File

@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from jax import random
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -60,11 +59,12 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
class DDIMSchedulerState:
# setable values
timesteps: jnp.ndarray
alphas_cumprod: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, num_train_timesteps: int):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
@dataclass
@@ -96,9 +96,19 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
if alpha for final step is 1 or the final alpha of the "non-previous" one.
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""
@property
def has_state(self):
return True
@register_to_config
def __init__(
self,
@@ -106,12 +116,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
@@ -124,19 +131,24 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# HACK for now - clean up later (PVP)
self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps)
def create_state(self):
return DDIMSchedulerState.create(
num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
)
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -144,9 +156,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
return variance
def set_timesteps(
self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0
) -> DDIMSchedulerState:
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -155,9 +165,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
the `FlaxDDIMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
offset = self.config.steps_offset
step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
@@ -172,9 +182,6 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
eta: float = 0.0,
use_clipped_model_output: bool = False,
return_dict: bool = True,
) -> Union[FlaxSchedulerOutput, Tuple]:
"""
@@ -213,44 +220,35 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
eta = 0.0
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
alphas_cumprod = state.alphas_cumprod
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# 4. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance = self._get_variance(timestep, prev_timestep)
variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
std_dev_t = eta * variance ** (0.5)
if use_clipped_model_output:
# the model_output is always re-derived from the clipped x_0 in Glide
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
key = random.split(key, num=1)
noise = random.normal(key=key, shape=model_output.shape)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample, state)
@@ -263,9 +261,14 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[:, None]
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
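To make the update above concrete, a standalone numeric sketch of the deterministic (`eta = 0`) DDIM step, i.e. formula (12) of https://arxiv.org/pdf/2010.02502.pdf; the alpha products are made up:

```python
import jax.numpy as jnp

alpha_prod_t, alpha_prod_t_prev = 0.5, 0.7
sample = jnp.array([1.0])        # x_t
model_output = jnp.array([0.2])  # predicted noise epsilon

# "predicted x_0"
pred_original_sample = (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t**0.5
# "direction pointing to x_t" (sigma_t = 0 when eta = 0)
pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction
print(prev_sample)  # x_{t-1}
```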

View File

@@ -1,4 +1,4 @@
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,13 +15,33 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class DDPMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -177,7 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
predict_epsilon=True,
generator=None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[DDPMSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -190,11 +210,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
@@ -242,7 +262,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (pred_prev_sample,)
return SchedulerOutput(prev_sample=pred_prev_sample)
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
def add_noise(
self,

View File

@@ -1,4 +1,4 @@
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -266,9 +266,14 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[..., None]
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
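The `match_shape` replacement above follows a simple broadcasting pattern: flatten the gathered per-timestep coefficients, then append singleton axes until they broadcast over the sample tensor. A standalone sketch:

```python
import jax.numpy as jnp

coeff = jnp.array([0.9, 0.5])      # one coefficient per batch element
samples = jnp.ones((2, 3, 8, 8))   # (batch, channels, height, width)

coeff = coeff.flatten()
while len(coeff.shape) < len(samples.shape):
    coeff = coeff[..., None]       # coeff ends up with shape (2, 1, 1, 1)

print((coeff * samples).shape)     # (2, 3, 8, 8)
```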

View File

@@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput):
denoising loop.
derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Derivative of predicted original image sample (x_0).
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
derivative: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
@@ -106,7 +110,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [
(
self.config.sigma_max
self.config.sigma_max**2
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in self.timesteps
@@ -153,7 +157,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sigma_hat (`float`): TODO
sigma_prev (`float`): TODO
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
Returns:
@@ -170,7 +174,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)
def step_correct(
self,
@@ -192,7 +198,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
Returns:
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
@@ -205,7 +211,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
return KarrasVeOutput(
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
)
def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()
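The schedule fix above multiplies by `sigma_max**2` rather than `sigma_max`, so the schedule geometrically interpolates between the squared sigmas. A standalone sketch with made-up values:

```python
import numpy as np

sigma_min, sigma_max, num_inference_steps = 0.02, 100.0, 5
timesteps = np.arange(0, num_inference_steps)[::-1]
schedule = [
    sigma_max**2 * (sigma_min**2 / sigma_max**2) ** (i / (num_inference_steps - 1))
    for i in timesteps
]
print(schedule)  # ramps geometrically from sigma_min**2 to sigma_max**2
```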

View File

@@ -47,7 +47,7 @@ class FlaxKarrasVeOutput(BaseOutput):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
Derivate of predicted original image sample (x_0).
Derivative of predicted original image sample (x_0).
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
"""
@@ -113,7 +113,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
schedule = [
(
self.config.sigma_max
self.config.sigma_max**2
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
)
for i in timesteps

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
@@ -20,7 +21,26 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class LMSDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -133,7 +153,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
@@ -144,12 +164,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
order: coefficient for multi-step inference.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
[`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
sigma = self.sigmas[timestep]
@@ -175,7 +195,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def add_noise(
self,

View File

@@ -198,8 +198,11 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sigmas = self.match_shape(state.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas
sigma = state.sigmas[timesteps].flatten()
while len(sigma.shape) < len(noise.shape):
sigma = sigma[..., None]
noisy_samples = original_samples + noise * sigma
return noisy_samples

View File

@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -59,7 +59,6 @@ class PNDMSchedulerState:
# setable values
_timesteps: jnp.ndarray
num_inference_steps: Optional[int] = None
_offset: int = 0
prk_timesteps: Optional[jnp.ndarray] = None
plms_timesteps: Optional[jnp.ndarray] = None
timesteps: Optional[jnp.ndarray] = None
@@ -104,8 +103,20 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
skip_prk_steps (`bool`):
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
before plms steps; defaults to `False`.
set_alpha_to_one (`bool`, default `False`):
each diffusion step uses the value of the alpha product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of the alpha product at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `steps_offset=1` and
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
"""
@property
def has_state(self):
return True
@register_to_config
def __init__(
self,
@@ -115,6 +126,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
@@ -132,16 +145,17 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4
self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)
def create_state(self):
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
def set_timesteps(
self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0
) -> PNDMSchedulerState:
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -150,16 +164,15 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
the `FlaxPNDMScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
offset = self.config.steps_offset
step_ratio = self.config.num_train_timesteps // num_inference_steps
# creates integer timesteps by multiplying by ratio
# rounding to avoid issues when num_inference_step is power of 3
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
_timesteps = _timesteps + offset
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset
state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps)
state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)
if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
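To make the offset handling concrete, a small worked example (editor's addition), assuming `num_train_timesteps=1000` and 50 inference steps as in stable diffusion:
# Editor's sketch: the new timestep computation with steps_offset=1.
import jax.numpy as jnp

num_train_timesteps, num_inference_steps, offset = 1000, 50, 1
step_ratio = num_train_timesteps // num_inference_steps                    # 20
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset
print(_timesteps[0], _timesteps[-1])                                       # 1 981
Combined with `set_alpha_to_one=False`, the final denoising step (timestep 1) then falls back to the alpha product at step 0 as its previous value, rather than a fixed 1.0.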
@@ -254,7 +267,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
)
diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1])
prev_timestep = timestep - diff_to_prev
timestep = state.prk_timesteps[state.counter // 4 * 4]
if state.counter % 4 == 0:
@@ -274,7 +287,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
# cur_sample should not be `None`
cur_sample = state.cur_sample if state.cur_sample is not None else sample
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)
if not return_dict:
@@ -320,7 +333,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
)
prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0)
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
if state.counter != 1:
state = state.replace(ets=state.ets.append(model_output))
@@ -344,7 +357,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
)
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
state = state.replace(counter=state.counter + 1)
if not return_dict:
@@ -352,7 +365,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state):
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(t−δ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
@@ -365,8 +378,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -395,9 +408,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
sqrt_alpha_prod = sqrt_alpha_prod[..., None]
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
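A toy boundary-case sketch (editor's addition, with an assumed linear beta schedule) of the `final_alpha_cumprod` fallback introduced in `_get_prev_sample` above:
# Editor's sketch: the last denoising step yields a negative prev_timestep,
# so the previous alpha product falls back to final_alpha_cumprod instead of
# indexing alphas_cumprod out of range.
import jax.numpy as jnp

betas = jnp.linspace(0.0001, 0.02, 1000)   # assumed schedule, not from the diff
alphas_cumprod = jnp.cumprod(1.0 - betas)
final_alpha_cumprod = alphas_cumprod[0]    # set_alpha_to_one=False

timestep = 1                               # last step of a 50-step run with steps_offset=1
prev_timestep = timestep - 1000 // 50      # -19
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
Note that without the `prev_timestep >= 0` guard, negative indexing would silently wrap around and read an alpha product from the end of the schedule.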

View File

@@ -192,14 +192,17 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
# also equation 47 shows the analog from SDE models to ancestral sampling methods
drift = drift - diffusion[:, None, None, None] ** 2 * model_output
diffusion = diffusion.flatten()
while len(diffusion.shape) < len(sample.shape):
diffusion = diffusion[:, None]
drift = drift - diffusion**2 * model_output
# equation 6: sample noise for the diffusion term of the SDE
key = random.split(key, num=1)
noise = random.normal(key=key, shape=sample.shape)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
if not return_dict:
return (prev_sample, prev_sample_mean, state)
@@ -248,8 +251,11 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
step_size = step_size * jnp.ones(sample.shape[0])
# compute corrected sample: model_output term and noise term
prev_sample_mean = sample + step_size[:, None, None, None] * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
step_size = step_size.flatten()
while len(step_size.shape) < len(sample.shape):
step_size = step_size[:, None]
prev_sample_mean = sample + step_size * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
if not return_dict:
return (prev_sample, state)

View File

@@ -1,7 +1,9 @@
import os
import random
import re
import unittest
from distutils.util import strtobool
from pathlib import Path
from typing import Union
import torch
@@ -92,3 +94,157 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
# --- pytest conf functions --- #
# to avoid multiple invocations from tests/conftest.py and examples/conftest.py - make sure it's called only once
pytest_opt_registered = {}
def pytest_addoption_shared(parser):
"""
This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
option.
"""
option = "--make-reports"
if option not in pytest_opt_registered:
parser.addoption(
option,
action="store",
default=False,
help="generate report files. The value of this option is used as a prefix to report names",
)
pytest_opt_registered[option] = 1
def pytest_terminal_summary_main(tr, id):
"""
Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
directory. The report files are prefixed with the test suite name.
This function emulates the --durations and -rA pytest arguments.
This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
there.
Args:
- tr: `terminalreporter` passed from `conftest.py`
- id: unique id like `tests` or `examples` that will be incorporated into the final report filenames - this is
needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
NB: this function taps into a private _pytest API and, while unlikely, it could break should pytest
make internal changes - it also calls default internal methods of terminalreporter which
can be hijacked by various `pytest-` plugins and interfere.
"""
from _pytest.config import create_terminal_writer
if not len(id):
id = "tests"
config = tr.config
orig_writer = config.get_terminal_writer()
orig_tbstyle = config.option.tbstyle
orig_reportchars = tr.reportchars
dir = "reports"
Path(dir).mkdir(parents=True, exist_ok=True)
report_files = {
k: f"{dir}/{id}_{k}.txt"
for k in [
"durations",
"errors",
"failures_long",
"failures_short",
"failures_line",
"passes",
"stats",
"summary_short",
"warnings",
]
}
# custom durations report
# note: there is no need to call pytest --durations=XX to get this separate report
# adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
dlist = []
for replist in tr.stats.values():
for rep in replist:
if hasattr(rep, "duration"):
dlist.append(rep)
if dlist:
dlist.sort(key=lambda x: x.duration, reverse=True)
with open(report_files["durations"], "w") as f:
durations_min = 0.05 # sec
f.write("slowest durations\n")
for i, rep in enumerate(dlist):
if rep.duration < durations_min:
f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
break
f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
def summary_failures_short(tr):
# expects the reports to use --tb=long (the default); we chop them down to the last frame here
reports = tr.getreports("failed")
if not reports:
return
tr.write_sep("=", "FAILURES SHORT STACK")
for rep in reports:
msg = tr._getfailureheadline(rep)
tr.write_sep("_", msg, red=True, bold=True)
# chop off the optional leading extra frames, leaving only the last one
longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
tr._tw.line(longrepr)
# note: not printing out any rep.sections to keep the report short
# use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
# adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
# note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
# pytest-instafail does that)
# report failures with line/short/long styles
config.option.tbstyle = "auto" # full tb
with open(report_files["failures_long"], "w") as f:
tr._tw = create_terminal_writer(config, f)
tr.summary_failures()
# config.option.tbstyle = "short" # short tb
with open(report_files["failures_short"], "w") as f:
tr._tw = create_terminal_writer(config, f)
summary_failures_short(tr)
config.option.tbstyle = "line" # one line per error
with open(report_files["failures_line"], "w") as f:
tr._tw = create_terminal_writer(config, f)
tr.summary_failures()
with open(report_files["errors"], "w") as f:
tr._tw = create_terminal_writer(config, f)
tr.summary_errors()
with open(report_files["warnings"], "w") as f:
tr._tw = create_terminal_writer(config, f)
tr.summary_warnings() # normal warnings
tr.summary_warnings() # final warnings
tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
with open(report_files["passes"], "w") as f:
tr._tw = create_terminal_writer(config, f)
tr.summary_passes()
with open(report_files["summary_short"], "w") as f:
tr._tw = create_terminal_writer(config, f)
tr.short_test_summary()
with open(report_files["stats"], "w") as f:
tr._tw = create_terminal_writer(config, f)
tr.summary_stats()
# restore:
tr._tw = orig_writer
tr.reportchars = orig_reportchars
config.option.tbstyle = orig_tbstyle
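For orientation, a quick sketch (editor's addition) of the filenames the helper above writes for a hypothetical `--make-reports=tests_torch_gpu` run:
# Editor's sketch: report files produced under reports/ for id="tests_torch_gpu",
# mirroring the report_files mapping above.
report_id = "tests_torch_gpu"
for k in ["durations", "errors", "failures_long", "failures_short",
          "failures_line", "passes", "stats", "summary_short", "warnings"]:
    print(f"reports/{report_id}_{k}.txt")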

View File

@@ -47,6 +47,9 @@ default_cache_path = os.path.join(hf_cache_home, "diffusers")
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
ONNX_WEIGHTS_NAME = "model.onnx"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"

View File

@@ -0,0 +1,11 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class FlaxStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["flax", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax", "transformers"])

View File

@@ -11,6 +11,27 @@ class FlaxModelMixin(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxUNet2DConditionModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxAutoencoderKL(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxDiffusionPipeline(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxDDIMScheduler(metaclass=DummyObject):
_backends = ["flax"]
@@ -46,13 +67,6 @@ class FlaxPNDMScheduler(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxUNet2DConditionModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
_backends = ["flax"]

View File

@@ -266,7 +266,7 @@ def reset_format() -> None:
def warning_advice(self, *args, **kwargs):
"""
This method is identical to `logger.warninging()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
"""
no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
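A minimal sketch (editor's addition) of the switch described in the docstring; the import path and `get_logger` helper are assumptions about the surrounding module, and the variable is read at call time, so setting it before the call suffices:
# Editor's sketch: silencing advisory warnings emitted via logger.warning_advice.
import os
os.environ["DIFFUSERS_NO_ADVISORY_WARNINGS"] = "1"

from diffusers.utils import logging  # assumed location of this logging module
logger = logging.get_logger(__name__)
logger.warning_advice("this advisory warning is suppressed")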

View File

@@ -59,10 +59,17 @@ class BaseOutput(OrderedDict):
if not len(class_fields):
raise ValueError(f"{self.__class__.__name__} has no fields.")
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
if other_fields_are_none and isinstance(first_field, dict):
for key, value in first_field.items():
self[key] = value
else:
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

44
tests/conftest.py Normal file
View File

@@ -0,0 +1,44 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# tests directory-specific settings - this file is run automatically
# by pytest before any tests are run
import sys
import warnings
from os.path import abspath, dirname, join
# allow having multiple repository checkouts without needing to remember to rerun
# 'pip install -e .[dev]' when switching between checkouts and running tests.
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
sys.path.insert(1, git_repo_path)
# silence FutureWarning warnings in tests since often we can't act on them until
# they become normal warnings - i.e. the tests still need to test the current functionality
warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_addoption(parser):
from diffusers.testing_utils import pytest_addoption_shared
pytest_addoption_shared(parser)
def pytest_terminal_summary(terminalreporter):
from diffusers.testing_utils import pytest_terminal_summary_main
make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
pytest_terminal_summary_main(terminalreporter, id=make_reports)

View File

@@ -246,3 +246,21 @@ class ModelTesterMixin:
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
def test_enable_disable_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
# at init model should have gradient checkpointing disabled
model = self.model_class(**init_dict)
self.assertFalse(model.is_gradient_checkpointing)
# check enable works
model.enable_gradient_checkpointing()
self.assertTrue(model.is_gradient_checkpointing)
# check disable works
model.disable_gradient_checkpointing()
self.assertFalse(model.is_gradient_checkpointing)
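As a usage sketch (editor's addition), toggling checkpointing on a model that supports it; the init arguments are borrowed from the `UNet2DConditionModel` test config further below:
# Editor's sketch: enable/disable gradient checkpointing on a small UNet.
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    sample_size=32,
)
model.enable_gradient_checkpointing()  # trades extra compute for lower memory
assert model.is_gradient_checkpointing
model.disable_gradient_checkpointing()
assert not model.is_gradient_checkpointing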

View File

@@ -18,8 +18,8 @@ import unittest
import torch
from diffusers import UNet2DModel
from diffusers.testing_utils import floats_tensor, torch_device
from diffusers import UNet2DConditionModel, UNet2DModel
from diffusers.testing_utils import floats_tensor, slow, torch_device
from .test_modeling_common import ModelTesterMixin
@@ -159,6 +159,82 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64),
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
"cross_attention_dim": 32,
"attention_head_dim": 8,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 2,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
out = model(**inputs_dict).sample
# run the backward pass on the model. For simplicity, we don't compute a loss
# and instead backprop on out.sum()
model.zero_grad()
out.sum().backward()
# save the output and parameter gradients for later comparison with
# the checkpointed run.
output_not_checkpointed = out.data.clone()
grad_not_checkpointed = {}
for name, param in model.named_parameters():
grad_not_checkpointed[name] = param.grad.data.clone()
model.enable_gradient_checkpointing()
out = model(**inputs_dict).sample
# run the backward pass on the model. For simplicity, we don't compute a loss
# and instead backprop on out.sum()
model.zero_grad()
out.sum().backward()
# save the output and parameter gradients for comparison with
# the non-checkpointed run.
output_checkpointed = out.data.clone()
grad_checkpointed = {}
for name, param in model.named_parameters():
grad_checkpointed[name] = param.grad.data.clone()
# compare the outputs and parameter gradients
self.assertTrue((output_checkpointed == output_not_checkpointed).all())
for name in grad_checkpointed:
self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
# TODO(Patrick) - Re-add this test after having cleaned up LDM
# def test_output_pretrained_spatial_transformer(self):
# model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
@@ -231,6 +307,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@slow
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
self.assertIsNotNone(model)
@@ -244,6 +321,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
@slow
def test_output_pretrained_ve_mid(self):
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
model.to(torch_device)

60
tests/test_outputs.py Normal file
View File

@@ -0,0 +1,60 @@
import unittest
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from diffusers.utils.outputs import BaseOutput
@dataclass
class CustomOutput(BaseOutput):
images: Union[List[PIL.Image.Image], np.ndarray]
class ConfigTester(unittest.TestCase):
def test_outputs_single_attribute(self):
outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4))
# check every way of getting the attribute
assert isinstance(outputs.images, np.ndarray)
assert outputs.images.shape == (1, 3, 4, 4)
assert isinstance(outputs["images"], np.ndarray)
assert outputs["images"].shape == (1, 3, 4, 4)
assert isinstance(outputs[0], np.ndarray)
assert outputs[0].shape == (1, 3, 4, 4)
# test with a non-tensor attribute
outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
# check every way of getting the attribute
assert isinstance(outputs.images, list)
assert isinstance(outputs.images[0], PIL.Image.Image)
assert isinstance(outputs["images"], list)
assert isinstance(outputs["images"][0], PIL.Image.Image)
assert isinstance(outputs[0], list)
assert isinstance(outputs[0][0], PIL.Image.Image)
def test_outputs_dict_init(self):
# test output reinitialization with a `dict` for compatibility with `accelerate`
outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)})
# check every way of getting the attribute
assert isinstance(outputs.images, np.ndarray)
assert outputs.images.shape == (1, 3, 4, 4)
assert isinstance(outputs["images"], np.ndarray)
assert outputs["images"].shape == (1, 3, 4, 4)
assert isinstance(outputs[0], np.ndarray)
assert outputs[0].shape == (1, 3, 4, 4)
# test with a non-tensor attribute
outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})
# check every way of getting the attribute
assert isinstance(outputs.images, list)
assert isinstance(outputs.images[0], PIL.Image.Image)
assert isinstance(outputs["images"], list)
assert isinstance(outputs["images"][0], PIL.Image.Image)
assert isinstance(outputs[0], list)
assert isinstance(outputs[0][0], PIL.Image.Image)

View File

@@ -46,11 +46,10 @@ from diffusers import (
UNet2DModel,
VQModel,
)
from diffusers.modeling_utils import WEIGHTS_NAME
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils import CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -1105,7 +1104,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
@@ -1325,14 +1324,58 @@ class PipelineTesterMixin(unittest.TestCase):
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_inpaint_pipeline_k_lms(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/red_cat_sitting_on_a_park_bench_k_lms.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
scheduler=lms,
safety_checker=self.dummy_safety_checker,
use_auth_token=True,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A red cat sitting on a park bench"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
init_image=init_image,
mask_image=mask_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-2
@slow
def test_stable_diffusion_onnx(self):
from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models
with tempfile.TemporaryDirectory() as tmpdirname:
convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
)
prompt = "A painting of a squirrel eating a burger"
np.random.seed(0)
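For completeness, a hedged usage sketch of the revised loading path exercised above: pulling the pre-exported `onnx` revision instead of converting locally (the prompt call is illustrative):
# Editor's sketch: loading the pre-exported ONNX weights, as in the test above.
from diffusers import StableDiffusionOnnxPipeline

pipe = StableDiffusionOnnxPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="onnx",
    provider="CUDAExecutionProvider",
    use_auth_token=True,
)
image = pipe("A painting of a squirrel eating a burger").images[0]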

View File

@@ -1,101 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
REPO_PATH = "."
# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
INTERNAL_OPS = [
"Assert",
"AssignVariableOp",
"EmptyTensorList",
"MergeV2Checkpoints",
"ReadVariableOp",
"ResourceGather",
"RestoreV2",
"SaveV2",
"ShardedFilename",
"StatefulPartitionedCall",
"StaticRegexFullMatch",
"VarHandleOp",
]
def onnx_compliance(saved_model_path, strict, opset):
saved_model = SavedModel()
onnx_ops = []
with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
onnx_opsets = json.load(f)["opsets"]
for i in range(1, opset + 1):
onnx_ops.extend(onnx_opsets[str(i)])
with open(saved_model_path, "rb") as f:
saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
for meta_graph in saved_model.meta_graphs:
# Add operations in the graph definition
model_op_names.update(node.op for node in meta_graph.graph_def.node)
# Go through the functions in the graph definition
for func in meta_graph.graph_def.library.function:
# Add operations in each function
model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
incompatible_ops = []
for op in model_op_names:
if op not in onnx_ops and op not in INTERNAL_OPS:
incompatible_ops.append(op)
if strict and len(incompatible_ops) > 0:
raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops)
elif len(incompatible_ops) > 0:
print(f"Found the following incompatible ops for the opset {opset}:")
print(*incompatible_ops, sep="\n")
else:
print(f"The saved model {saved_model_path} can properly be converted with ONNX.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
parser.add_argument(
"--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
)
parser.add_argument(
"--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
)
parser.add_argument(
"--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)"
)
args = parser.parse_args()
if args.framework == "onnx":
onnx_compliance(args.saved_model_path, args.strict, args.opset)

View File

@@ -0,0 +1,34 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# this script reports modified .py files under the desired list of top-level sub-dirs passed as a list of arguments, e.g.:
# python ./utils/get_modified_files.py utils src tests examples
#
# it uses git to find the forking point and which files were modified - i.e. files not under git won't be considered
# since the output of this script is fed into Makefile commands it doesn't print a newline after the results
import re
import subprocess
import sys
fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8")
modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split()
joined_dirs = "|".join(sys.argv[1:])
regex = re.compile(rf"^({joined_dirs}).*?\.py$")
relevant_modified_files = [x for x in modified_files if regex.match(x)]
print(" ".join(relevant_modified_files), end="")