Merge branch 'main' of https://github.com/huggingface/diffusers into dreambooth-example
13 .github/workflows/pr_tests.yml vendored
@@ -41,4 +41,15 @@ jobs:
      - name: Run all non-slow selected tests on CPU
        run: |
          python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s tests/
          python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/

      - name: Failure short reports
        if: ${{ failure() }}
        run: cat reports/tests_torch_cpu_failures_short.txt

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v2
        with:
          name: pr_torch_test_reports
          path: reports
64 .github/workflows/push_tests.yml vendored
@@ -49,4 +49,66 @@ jobs:
        env:
          HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
        run: |
          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s tests/
          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_gpu tests/

      - name: Failure short reports
        if: ${{ failure() }}
        run: cat reports/tests_torch_gpu_failures_short.txt

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v2
        with:
          name: torch_test_reports
          path: reports

  run_examples_single_gpu:
    name: Examples tests
    strategy:
      fail-fast: false
      matrix:
        machine_type: [ single-gpu ]
    runs-on: [ self-hosted, docker-gpu, '${{ matrix.machine_type }}' ]
    container:
      image: nvcr.io/nvidia/pytorch:22.07-py3
      options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/

    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2

      - name: NVIDIA-SMI
        run: |
          nvidia-smi

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip uninstall -y torch torchvision torchtext
          python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
          python -m pip install -e .[quality,test,training]

      - name: Environment
        run: |
          python utils/print_env.py

      - name: Run example tests on GPU
        env:
          HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
        run: |
          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_gpu examples/

      - name: Failure short reports
        if: ${{ failure() }}
        run: cat reports/examples_torch_gpu_failures_short.txt

      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v2
        with:
          name: examples_test_reports
          path: reports
@@ -4,8 +4,9 @@
[default.extend-identifiers]

[default.extend-words]
NIN_="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
NIN="NIN" # NIN is used in scripts/convert_ncsnpp_original_checkpoint_to_diffusers.py
nd="np" # nd may be np (numpy)
parms="parms" # parms is used in scripts/convert_original_stable_diffusion_to_diffusers.py


[files]
@@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module

## AutoencoderKL
[[autodoc]] AutoencoderKL

## FlaxModelMixin
[[autodoc]] FlaxModelMixin

## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput

## FlaxUNet2DConditionModel
[[autodoc]] FlaxUNet2DConditionModel

## FlaxDecoderOutput
[[autodoc]] models.vae_flax.FlaxDecoderOutput

## FlaxAutoencoderKLOutput
[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput

## FlaxAutoencoderKL
[[autodoc]] FlaxAutoencoderKL
@@ -53,7 +53,7 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochatic_karras_ve](./stochatic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |

**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
@@ -85,7 +85,7 @@ not be used for training. If you want to store the gradients during the forward
We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire
all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.

- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file iteslf, should be inherited from (and only from) the [`DiffusionPipeline` class](.../diffusion_pipeline) or be directly attached to the model and scheduler components of the pipeline.
- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](.../diffusion_pipeline) or be directly attached to the model and scheduler components of the pipeline.
- **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and
use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most
logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method.
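As a rough illustration of the "couple of lines of code" goal described in the guidelines above — a minimal sketch that is not part of this diff; the checkpoint id and the `use_auth_token` requirement are assumptions:

```python
# Minimal sketch of the intended pipeline user experience (illustrative only).
# "CompVis/stable-diffusion-v1-4" is an assumed checkpoint; access may require accepting its license.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
image = pipe("a photograph of an astronaut riding a horse").images[0]
image.save("astronaut.png")
```

Pre-processing, the unrolled denoising loop, and post-processing all happen inside the pipeline's `__call__`, which is exactly what the guideline asks for.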
@@ -44,6 +44,6 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochatic_karras_ve](./api/pipelines/stochatic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |

**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
@@ -27,6 +27,6 @@ pip install diffusers

### Schedulers

### Pipeliens
### Pipelines
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.


# Unonditional Image Generation
# Unconditional Image Generation

The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference
@@ -2,5 +2,6 @@

**Community** examples consist of both inference and training examples that have been added by the community.

| Example | Description | Author | |
|:----------|:-------------|:-------------|------:|
| Example | Description | Author | Colab |
|:----------|:----------------------|:-----------------|----------:|
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion| [Suraj Patil](https://github.com/patil-suraj/) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) |
321 examples/community/clip_guided_stable_diffusion.py Normal file
@@ -0,0 +1,321 @@
import inspect
from typing import List, Optional, Union

import torch
from torch import nn
from torch.nn import functional as F

from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer


class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cut_power=1.0):
        super().__init__()

        self.cut_size = cut_size
        self.cut_power = cut_power

    def forward(self, pixel_values, num_cutouts):
        sideY, sideX = pixel_values.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(num_cutouts):
            size = int(torch.rand([]) ** self.cut_power * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = pixel_values[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)


def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def set_requires_grad(model, value):
    for param in model.parameters():
        param.requires_grad = value


class CLIPGuidedStableDiffusion(DiffusionPipeline):
    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
    - https://github.com/Jack000/glid-3-xl
    - https://github.dev/crowsonkb/k-diffusion
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        clip_model: CLIPModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            clip_model=clip_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
        self.make_cutouts = MakeCutouts(feature_extractor.size)

        set_requires_grad(self.text_encoder, False)
        set_requires_grad(self.clip_model, False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        self.enable_attention_slicing(None)

    def freeze_vae(self):
        set_requires_grad(self.vae, False)

    def unfreeze_vae(self):
        set_requires_grad(self.vae, True)

    def freeze_unet(self):
        set_requires_grad(self.unet, False)

    def unfreeze_unet(self):
        set_requires_grad(self.unet, True)

    @torch.enable_grad()
    def cond_fn(
        self,
        latents,
        timestep,
        index,
        text_embeddings,
        noise_pred_original,
        text_embeddings_clip,
        clip_guidance_scale,
        num_cutouts,
        use_cutouts=True,
    ):
        latents = latents.detach().requires_grad_()

        if isinstance(self.scheduler, LMSDiscreteScheduler):
            sigma = self.scheduler.sigmas[index]
            # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
            latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
        else:
            latent_model_input = latents

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

        if isinstance(self.scheduler, PNDMScheduler):
            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
            beta_prod_t = 1 - alpha_prod_t
            # compute predicted original sample from predicted noise also called
            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)

            fac = torch.sqrt(beta_prod_t)
            sample = pred_original_sample * (fac) + latents * (1 - fac)
        elif isinstance(self.scheduler, LMSDiscreteScheduler):
            sigma = self.scheduler.sigmas[index]
            sample = latents - sigma * noise_pred
        else:
            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")

        sample = 1 / 0.18215 * sample
        image = self.vae.decode(sample).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        if use_cutouts:
            image = self.make_cutouts(image, num_cutouts)
        else:
            image = transforms.Resize(self.feature_extractor.size)(image)
        image = self.normalize(image)

        image_embeddings_clip = self.clip_model.get_image_features(image).float()
        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

        if use_cutouts:
            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
            dists = dists.view([num_cutouts, sample.shape[0], -1])
            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
        else:
            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale

        grads = -torch.autograd.grad(loss, latents)[0]

        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents.detach() + grads * (sigma**2)
            noise_pred = noise_pred_original
        else:
            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
        return noise_pred, latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        clip_guidance_scale: Optional[float] = 100,
        clip_prompt: Optional[Union[str, List[str]]] = None,
        num_cutouts: Optional[int] = 4,
        use_cutouts: Optional[bool] = True,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        if clip_guidance_scale > 0:
            if clip_prompt is not None:
                clip_text_input = self.tokenizer(
                    clip_prompt,
                    padding="max_length",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids.to(self.device)
            else:
                clip_text_input = text_input.input_ids.to(self.device)
            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_device = "cpu" if self.device.type == "mps" else self.device
        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
        if latents is None:
            latents = torch.randn(
                latents_shape,
                generator=generator,
                device=latents_device,
            )
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform classifier free guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # perform clip guidance
            if clip_guidance_scale > 0:
                text_embeddings_for_guidance = (
                    text_embeddings.chunk(2)[0] if do_classifier_free_guidance else text_embeddings
                )
                noise_pred, latents = self.cond_fn(
                    latents,
                    t,
                    i,
                    text_embeddings_for_guidance,
                    noise_pred,
                    text_embeddings_clip,
                    clip_guidance_scale,
                    num_cutouts,
                    use_cutouts,
                )

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents).prev_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
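Because the community pipeline above is self-contained, it can be assembled from the components of an existing Stable Diffusion checkpoint plus a CLIP model. A rough usage sketch — not part of this commit; the model ids are assumptions:

```python
# Assemble CLIPGuidedStableDiffusion from pretrained components (illustrative only).
from diffusers import StableDiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel

clip_id = "openai/clip-vit-large-patch14"   # assumed CLIP checkpoint
sd_id = "CompVis/stable-diffusion-v1-4"     # assumed Stable Diffusion checkpoint

sd_pipe = StableDiffusionPipeline.from_pretrained(sd_id, use_auth_token=True)
guided_pipe = CLIPGuidedStableDiffusion(
    vae=sd_pipe.vae,
    text_encoder=sd_pipe.text_encoder,
    clip_model=CLIPModel.from_pretrained(clip_id),
    tokenizer=sd_pipe.tokenizer,
    unet=sd_pipe.unet,
    scheduler=sd_pipe.scheduler,
    feature_extractor=CLIPFeatureExtractor.from_pretrained(clip_id),
).to("cuda")

image = guided_pipe("fantasy landscape", clip_guidance_scale=100, num_cutouts=4).images[0]
```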
45 examples/conftest.py Normal file
@@ -0,0 +1,45 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# tests directory-specific settings - this file is run automatically
# by pytest before any tests are run

import sys
import warnings
from os.path import abspath, dirname, join


# allow having multiple repository checkouts and not needing to remember to rerun
# 'pip install -e .[dev]' when switching between checkouts and running tests.
git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
sys.path.insert(1, git_repo_path)


# silence FutureWarning warnings in tests since often we can't act on them until
# they become normal warnings - i.e. the tests still need to test the current functionality
warnings.simplefilter(action="ignore", category=FutureWarning)


def pytest_addoption(parser):
    from diffusers.testing_utils import pytest_addoption_shared

    pytest_addoption_shared(parser)


def pytest_terminal_summary(terminalreporter):
    from diffusers.testing_utils import pytest_terminal_summary_main

    make_reports = terminalreporter.config.getoption("--make-reports")
    if make_reports:
        pytest_terminal_summary_main(terminalreporter, id=make_reports)
124 examples/test_examples.py Normal file
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
from typing import List

from accelerate.utils import write_basic_config
from diffusers.testing_utils import slow


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()


# These utils relate to ensuring the right error message is received when running scripts
class SubprocessCallException(Exception):
    pass


def run_command(command: List[str], return_stdout=False):
    """
    Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
    if an error occured while running `command`
    """
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
        if return_stdout:
            if hasattr(output, "decode"):
                output = output.decode("utf-8")
            return output
    except subprocess.CalledProcessError as e:
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
        ) from e


stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class ExamplesTestsAccelerate(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._tmpdir = tempfile.mkdtemp()
        cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")

        write_basic_config(save_location=cls.configPath)
        cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]

    @classmethod
    def tearDownClass(cls):
        super().tearDownClass()
        shutil.rmtree(cls._tmpdir)

    @slow
    def test_train_unconditional(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/unconditional_image_generation/train_unconditional.py
                --dataset_name huggan/few-shot-aurora
                --resolution 64
                --output_dir {tmpdir}
                --train_batch_size 4
                --num_epochs 1
                --gradient_accumulation_steps 1
                --learning_rate 1e-3
                --lr_warmup_steps 5
                --mixed_precision fp16
                """.split()

            run_command(self._launch_args + test_args, return_stdout=True)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
            # logging test
            self.assertTrue(len(os.listdir(os.path.join(tmpdir, "logs", "train_unconditional"))) > 0)

    @slow
    def test_textual_inversion(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/textual_inversion/textual_inversion.py
                --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4
                --use_auth_token
                --train_data_dir docs/source/imgs
                --learnable_property object
                --placeholder_token <cat-toy>
                --initializer_token toy
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 2
                --max_train_steps 10
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --mixed_precision fp16
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.bin")))
@@ -13,11 +13,14 @@
# limitations under the License.

import argparse
import os
import shutil
from pathlib import Path

import torch
from torch.onnx import export

import onnx
from diffusers import StableDiffusionOnnxPipeline, StableDiffusionPipeline
from diffusers.onnx_utils import OnnxRuntimeModel
from packaging import version
@@ -92,10 +95,11 @@ def convert_models(model_path: str, output_path: str, opset: int):
    )

    # UNET
    unet_path = output_path / "unet" / "model.onnx"
    onnx_export(
        pipeline.unet,
        model_args=(torch.randn(2, 4, 64, 64), torch.LongTensor([0, 1]), torch.randn(2, 77, 768), False),
        output_path=output_path / "unet" / "model.onnx",
        output_path=unet_path,
        ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"],
        output_names=["out_sample"], # has to be different from "sample" for correct tracing
        dynamic_axes={
@@ -106,6 +110,21 @@ def convert_models(model_path: str, output_path: str, opset: int):
        opset=opset,
        use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
    )
    unet_model_path = str(unet_path.absolute().as_posix())
    unet_dir = os.path.dirname(unet_model_path)
    unet = onnx.load(unet_model_path)
    # clean up existing tensor files
    shutil.rmtree(unet_dir)
    os.mkdir(unet_dir)
    # collate external tensor files into one
    onnx.save_model(
        unet,
        unet_model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="weights.pb",
        convert_attribute=False,
    )

    # VAE ENCODER
    vae_encoder = pipeline.vae
23 setup.py
@@ -77,7 +77,7 @@ from setuptools import find_packages, setup
# 1. all dependencies should be listed here with their version requirements if any
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [
    "Pillow",
    "Pillow<10.0", # keep the PIL.Image.Resampling deprecation away
    "accelerate>=0.11.0",
    "black==22.8",
    "datasets",
@@ -90,8 +90,9 @@ _deps = [
    "isort>=5.5.4",
    "jax>=0.2.8,!=0.3.2,<=0.3.6",
    "jaxlib>=0.1.65,<=0.3.6",
    "modelcards==0.1.4",
    "modelcards>=0.1.4",
    "numpy",
    "onnxruntime-gpu",
    "pytest",
    "pytest-timeout",
    "pytest-xdist",
@@ -100,6 +101,7 @@ _deps = [
    "requests",
    "tensorboard",
    "torch>=1.4",
    "torchvision",
    "transformers>=4.21.0",
]

@@ -171,10 +173,19 @@ extras = {}

extras = {}
extras["quality"] = ["black==22.8", "isort>=5.5.4", "flake8>=3.8.3", "hf-doc-builder"]
extras["docs"] = ["hf-doc-builder"]
extras["training"] = ["accelerate", "datasets", "tensorboard", "modelcards"]
extras["test"] = ["datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "transformers"]
extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
extras["test"] = deps_list(
    "datasets",
    "onnxruntime-gpu",
    "pytest",
    "pytest-timeout",
    "pytest-xdist",
    "scipy",
    "torchvision",
    "transformers"
)
extras["torch"] = deps_list("torch")

if os.name == "nt": # windows
@@ -65,6 +65,8 @@ else:
if is_flax_available():
    from .modeling_flax_utils import FlaxModelMixin
    from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
    from .models.vae_flax import FlaxAutoencoderKL
    from .pipeline_flax_utils import FlaxDiffusionPipeline
    from .schedulers import (
        FlaxDDIMScheduler,
        FlaxDDPMScheduler,
@@ -75,3 +77,8 @@ if is_flax_available():
    )
else:
    from .utils.dummy_flax_objects import * # noqa F403

if is_flax_available() and is_transformers_available():
    from .pipelines import FlaxStableDiffusionPipeline
else:
    from .utils.dummy_flax_and_transformers_objects import * # noqa F403
@@ -154,15 +154,25 @@ class ConfigMixin:

        """
        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)

        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)

        # Allow dtype to be specified on initialization
        if "dtype" in unused_kwargs:
            init_dict["dtype"] = unused_kwargs.pop("dtype")

        # Return model and optionally state and/or unused_kwargs
        model = cls(**init_dict)
        return_tuple = (model,)

        # Flax schedulers have a state, so return it.
        if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False):
            state = model.create_state()
            return_tuple += (state,)

        if return_unused_kwargs:
            return model, unused_kwargs
            return return_tuple + (unused_kwargs,)
        else:
            return model
            return return_tuple if len(return_tuple) > 1 else model

    @classmethod
    def get_config_dict(
@@ -272,7 +282,7 @@ class ConfigMixin:
        # remove general kwargs if present in dict
        if "kwargs" in expected_keys:
            expected_keys.remove("kwargs")
        # remove flax interal keys
        # remove flax internal keys
        if hasattr(cls, "_flax_internal_args"):
            for arg in cls._flax_internal_args:
                expected_keys.remove(arg)
@@ -446,6 +456,9 @@ def flax_register_to_config(cls):

        # Make sure init_kwargs override default kwargs
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        if "dtype" in new_kwargs:
            new_kwargs.pop("dtype")

        # Get positional arguments aligned with kwargs
        for i, arg in enumerate(args):
@@ -2,7 +2,7 @@
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update``
deps = {
    "Pillow": "Pillow",
    "Pillow": "Pillow<10.0",
    "accelerate": "accelerate>=0.11.0",
    "black": "black==22.8",
    "datasets": "datasets",
@@ -15,8 +15,10 @@ deps = {
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
    "jaxlib": "jaxlib>=0.1.65,<=0.3.6",
    "modelcards": "modelcards==0.1.4",
    "modelcards": "modelcards>=0.1.4",
    "numpy": "numpy",
    "onnxruntime": "onnxruntime",
    "onnxruntime-gpu": "onnxruntime-gpu",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
@@ -25,5 +27,6 @@ deps = {
    "requests": "requests",
    "tensorboard": "tensorboard",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.21.0",
}
117 src/diffusers/modeling_flax_pytorch_utils.py Normal file
@@ -0,0 +1,117 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
import re

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from .utils import logging


logger = logging.get_logger(__name__)


def rename_key(key):
    regex = r"\w+[.]\d+"
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key


#####################
# PyTorch => Flax #
#####################

# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

    # conv norm or layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and (pt_tuple_key[-1] == "bias")
        and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
        and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
    ):
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor
    elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor

    # embedding
    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
        pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        return renamed_pt_tuple_key, pt_tensor

    # conv layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight":
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    if pt_tuple_key[-1] == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    if pt_tuple_key[-1] == "beta":
        return renamed_pt_tuple_key, pt_tensor

    return pt_tuple_key, pt_tensor


def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
    # Step 1: Convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    # Step 2: Since the model is stateless, get random Flax params
    random_flax_params = flax_model.init_weights(PRNGKey(init_key))

    random_flax_state_dict = flatten_dict(random_flax_params)
    flax_state_dict = {}

    # Need to change some parameters name to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
        renamed_pt_key = rename_key(pt_key)
        pt_tuple_key = tuple(renamed_pt_key.split("."))

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)

        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(flax_state_dict)
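In practice the converter above is reached through the `from_pt=True` path added to `FlaxModelMixin.from_pretrained` in the next hunk, rather than being called directly. A brief sketch — not part of this commit; the repo id is an assumption:

```python
# Loading PyTorch weights into a Flax model; from_pretrained calls
# convert_pytorch_state_dict_to_flax internally when from_pt=True.
from diffusers import FlaxUNet2DConditionModel

model, params = FlaxUNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", from_pt=True
)
```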
@@ -27,12 +27,18 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

from .modeling_utils import WEIGHTS_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
from .modeling_utils import load_state_dict
from .utils import (
    CONFIG_NAME,
    DIFFUSERS_CACHE,
    FLAX_WEIGHTS_NAME,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    WEIGHTS_NAME,
    logging,
)


FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"

logger = logging.get_logger(__name__)

@@ -45,7 +51,7 @@ class FlaxModelMixin:
    """
    config_name = CONFIG_NAME
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    _flax_internal_args = ["name", "parent"]
    _flax_internal_args = ["name", "parent", "dtype"]

    @classmethod
    def _from_config(cls, config, **kwargs):
@@ -245,6 +251,8 @@ class FlaxModelMixin:
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            from_pt (`bool`, *optional*, defaults to `False`):
                Load the model weights from a PyTorch checkpoint save file.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -272,6 +280,7 @@ class FlaxModelMixin:
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        from_pt = kwargs.pop("from_pt", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
@@ -294,32 +303,44 @@ class FlaxModelMixin:
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            # model args
            dtype=dtype,
            **kwargs,
        )

        # Load model
        if os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
        pretrained_path_with_subfolder = (
            pretrained_model_name_or_path
            if subfolder is None
            else os.path.join(pretrained_model_name_or_path, subfolder)
        )
        if os.path.isdir(pretrained_path_with_subfolder):
            if from_pt:
                if not os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                    raise EnvironmentError(
                        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_path_with_subfolder} "
                    )
                model_file = os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)
            elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)):
                # Load from a Flax checkpoint
                model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
            # At this stage we don't have a weight file so we will raise an error.
            elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
                model_file = os.path.join(pretrained_path_with_subfolder, FLAX_WEIGHTS_NAME)
            # Check if pytorch weights exist instead
            elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
                raise EnvironmentError(
                    f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
                    "but there is a file for PyTorch weights."
                    f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
                    " using `from_pt=True`."
                )
            else:
                raise EnvironmentError(
                    f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
                    f"{pretrained_model_name_or_path}."
                    f"{pretrained_path_with_subfolder}."
                )
        else:
            try:
                model_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=FLAX_WEIGHTS_NAME,
                    filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
@@ -369,25 +390,32 @@ class FlaxModelMixin:
                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                )

        try:
            with open(model_file, "rb") as state_f:
                state = from_bytes(cls, state_f.read())
        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
        if from_pt:
            # Step 1: Get the pytorch file
            pytorch_model_file = load_state_dict(model_file)

            # Step 2: Convert the weights
            state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
        else:
            try:
                with open(model_file) as f:
                    if f.read().startswith("version"):
                        raise OSError(
                            "You seem to have cloned a repository without having git-lfs installed. Please"
                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                            " folder you cloned."
                        )
                    else:
                        raise ValueError from e
            except (UnicodeDecodeError, ValueError):
                raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
        # make sure all arrays are stored as jnp.ndarray
        # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
        # https://github.com/google/flax/issues/1261
                with open(model_file, "rb") as state_f:
                    state = from_bytes(cls, state_f.read())
            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
                try:
                    with open(model_file) as f:
                        if f.read().startswith("version"):
                            raise OSError(
                                "You seem to have cloned a repository without having git-lfs installed. Please"
                                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                                " folder you cloned."
                            )
                        else:
                            raise ValueError from e
                except (UnicodeDecodeError, ValueError):
                    raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
            # make sure all arrays are stored as jnp.ndarray
            # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
            # https://github.com/google/flax/issues/1261
        state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

        # flatten dicts
@@ -408,7 +436,7 @@ class FlaxModelMixin:
            )
            cls._missing_keys = missing_keys

        # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
        # matching the weights in the model.
        mismatched_keys = []
        for key in state.keys():
@@ -15,6 +15,7 @@
# limitations under the License.

import os
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

import torch
@@ -24,10 +25,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging


WEIGHTS_NAME = "diffusion_pytorch_model.bin"
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging


logger = logging.get_logger(__name__)
@@ -124,10 +122,42 @@ class ModelMixin(torch.nn.Module):
    """
    config_name = CONFIG_NAME
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    _supports_gradient_checkpointing = False

    def __init__(self):
        super().__init__()

    @property
    def is_gradient_checkpointing(self) -> bool:
        """
        Whether gradient checkpointing is activated for this model or not.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

    def enable_gradient_checkpointing(self):
        """
        Activates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        if not self._supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def disable_gradient_checkpointing(self):
        """
        Deactivates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        if self._supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
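The gradient-checkpointing switches added above are meant to be toggled on a concrete model class that opts in via `_supports_gradient_checkpointing`. A hedged usage sketch — the UNet checkpoint and its opt-in are assumptions, not shown in this diff:

```python
# Trade extra compute for lower memory during training by checkpointing activations.
from diffusers import UNet2DConditionModel

# Assumed checkpoint; assumed that UNet2DConditionModel sets _supports_gradient_checkpointing = True.
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True
)
unet.enable_gradient_checkpointing()   # raises ValueError if the class does not opt in
assert unet.is_gradient_checkpointing
unet.disable_gradient_checkpointing()
```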
@@ -12,6 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel
from ..utils import is_flax_available, is_torch_available


if is_torch_available():
    from .unet_2d import UNet2DModel
    from .unet_2d_condition import UNet2DConditionModel
    from .vae import AutoencoderKL, VQModel

if is_flax_available():
    from .unet_2d_condition_flax import FlaxUNet2DConditionModel
    from .vae_flax import FlaxAutoencoderKL
@@ -249,13 +249,15 @@ class CrossAttention(nn.Module):
        return tensor

    def forward(self, hidden_states, context=None, mask=None):
        batch_size, sequence_length, dim = hidden_states.shape
        batch_size, sequence_length, _ = hidden_states.shape

        query = self.to_q(hidden_states)
        context = context if context is not None else hidden_states
        key = self.to_k(context)
        value = self.to_v(context)

        dim = query.shape[-1]

        query = self.reshape_heads_to_batch_dim(query)
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)
@@ -17,6 +17,22 @@ import jax.numpy as jnp
|
||||
|
||||
|
||||
class FlaxAttentionBlock(nn.Module):
|
||||
r"""
|
||||
A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
|
||||
|
||||
Parameters:
|
||||
query_dim (:obj:`int`):
|
||||
Input hidden states dimension
|
||||
heads (:obj:`int`, *optional*, defaults to 8):
|
||||
Number of heads
|
||||
dim_head (:obj:`int`, *optional*, defaults to 64):
|
||||
Hidden states dimension inside each head
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
|
||||
"""
|
||||
query_dim: int
|
||||
heads: int = 8
|
||||
dim_head: int = 64
|
||||
@@ -32,7 +48,7 @@ class FlaxAttentionBlock(nn.Module):
|
||||
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
|
||||
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
|
||||
|
||||
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
|
||||
self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
@@ -74,6 +90,23 @@ class FlaxAttentionBlock(nn.Module):
|
||||
|
||||
|
||||
class FlaxBasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
|
||||
https://arxiv.org/abs/1706.03762
|
||||
|
||||
|
||||
Parameters:
|
||||
dim (:obj:`int`):
|
||||
Inner hidden states dimension
|
||||
n_heads (:obj:`int`):
|
||||
Number of heads
|
||||
d_head (:obj:`int`):
|
||||
Hidden states dimension inside each head
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
dim: int
|
||||
n_heads: int
|
||||
d_head: int
|
||||
@@ -82,9 +115,9 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
|
||||
def setup(self):
|
||||
# self attention
|
||||
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
|
||||
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
|
||||
# cross attention
|
||||
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
|
||||
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
|
||||
self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
|
||||
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
||||
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
||||
@@ -93,12 +126,12 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
def __call__(self, hidden_states, context, deterministic=True):
|
||||
# self attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
|
||||
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# cross attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
|
||||
hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# feed forward
|
||||
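The renames above (`self_attn` to `attn1`, `cross_attn` to `attn2`, and `to_out` to `to_out_0` earlier in the file) presumably follow the same motivation the `net_2` comment below spells out: keeping Flax attribute names aligned with the PyTorch module names so converted checkpoints map onto parameters mechanically. A rough illustrative sketch of that mapping, ignoring the kernel transposes a real converter also performs:

def pt_key_to_flax_path(pt_key: str):
    # e.g. "attn1.to_q.weight" -> ("attn1", "to_q", "kernel")
    parts = pt_key.replace("to_out.0", "to_out_0").split(".")
    if parts[-1] == "weight":
        parts[-1] = "kernel"
    return tuple(parts)

print(pt_key_to_flax_path("attn1.to_q.weight"))      # ('attn1', 'to_q', 'kernel')
print(pt_key_to_flax_path("attn2.to_out.0.weight"))  # ('attn2', 'to_out_0', 'kernel')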
@@ -110,6 +143,25 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class FlaxSpatialTransformer(nn.Module):
|
||||
r"""
|
||||
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
|
||||
https://arxiv.org/pdf/1506.02025.pdf
|
||||
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`):
|
||||
Input number of channels
|
||||
n_heads (:obj:`int`):
|
||||
Number of heads
|
||||
d_head (:obj:`int`):
|
||||
Hidden states dimension inside each head
|
||||
depth (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of transformer blocks
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
in_channels: int
|
||||
n_heads: int
|
||||
d_head: int
|
||||
@@ -144,7 +196,6 @@ class FlaxSpatialTransformer(nn.Module):
|
||||
|
||||
def __call__(self, hidden_states, context, deterministic=True):
|
||||
batch, height, width, channels = hidden_states.shape
|
||||
# import ipdb; ipdb.set_trace()
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
@@ -163,18 +214,56 @@ class FlaxSpatialTransformer(nn.Module):
|
||||
|
||||
|
||||
class FlaxGluFeedForward(nn.Module):
|
||||
r"""
|
||||
Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
|
||||
https://arxiv.org/abs/2002.05202
|
||||
|
||||
Parameters:
|
||||
dim (:obj:`int`):
|
||||
Inner hidden states dimension
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
dim: int
|
||||
dropout: float = 0.0
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
# The second linear layer needs to be called
|
||||
# net_2 for now to match the index of the Sequential layer
|
||||
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
|
||||
self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True):
|
||||
hidden_states = self.net_0(hidden_states)
|
||||
hidden_states = self.net_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxGEGLU(nn.Module):
|
||||
r"""
|
||||
Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
|
||||
https://arxiv.org/abs/2002.05202.
|
||||
|
||||
Parameters:
|
||||
dim (:obj:`int`):
|
||||
Input hidden states dimension
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
dim: int
|
||||
dropout: float = 0.0
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
inner_dim = self.dim * 4
|
||||
self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
|
||||
self.dense2 = nn.Dense(self.dim, dtype=self.dtype)
|
||||
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True):
|
||||
hidden_states = self.dense1(hidden_states)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
|
||||
hidden_states = hidden_linear * nn.gelu(hidden_gelu)
|
||||
hidden_states = self.dense2(hidden_states)
|
||||
return hidden_states
|
||||
return hidden_linear * nn.gelu(hidden_gelu)
|
||||
|
||||
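FlaxGEGLU above implements the gated variant of GLU from the cited paper: one projection to twice the inner width, split into a linear half and a gate half, with the gate passed through GELU. A self-contained Flax sketch of the same pattern (ToyGEGLU is illustrative, not part of the library):

import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyGEGLU(nn.Module):
    dim: int

    @nn.compact
    def __call__(self, x):
        inner_dim = self.dim * 4
        h = nn.Dense(inner_dim * 2)(x)           # project to 2 * inner_dim
        linear, gate = jnp.split(h, 2, axis=-1)  # value half and gate half
        return linear * nn.gelu(gate)

x = jnp.ones((1, 16, 32))
params = ToyGEGLU(dim=32).init(jax.random.PRNGKey(0), x)
y = ToyGEGLU(dim=32).apply(params, x)            # shape (1, 16, 128)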
@@ -19,7 +19,7 @@ import jax.numpy as jnp
|
||||
|
||||
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
|
||||
# less general (only handles the case we currently need).
|
||||
def get_sinusoidal_embeddings(timesteps, embedding_dim):
|
||||
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
@@ -29,7 +29,7 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim):
|
||||
embeddings. :return: an [N x dim] tensor of positional embeddings.
|
||||
"""
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = math.log(10000) / (half_dim - freq_shift)
|
||||
emb = jnp.exp(jnp.arange(half_dim) * -emb)
|
||||
emb = timesteps[:, None] * emb[None, :]
|
||||
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
|
||||
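The new `freq_shift` argument only changes the denominator of the frequency spacing: the original DDPM convention divides by `half_dim - 1`, while `freq_shift=0` divides by `half_dim`, giving a slightly gentler decay. A quick illustrative check:

import math
import jax.numpy as jnp

def frequencies(embedding_dim, freq_shift):
    half_dim = embedding_dim // 2
    exponent = math.log(10000) / (half_dim - freq_shift)
    return jnp.exp(jnp.arange(half_dim) * -exponent)

print(frequencies(8, freq_shift=1))  # DDPM spacing (denominator 3)
print(frequencies(8, freq_shift=0))  # gentler decay (denominator 4), the default
                                     # picked up by FlaxUNet2DConditionModel below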
@@ -37,6 +37,15 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim):
|
||||
|
||||
|
||||
class FlaxTimestepEmbedding(nn.Module):
|
||||
r"""
|
||||
Time step Embedding Module. Learns embeddings for input time steps.
|
||||
|
||||
Args:
|
||||
time_embed_dim (`int`, *optional*, defaults to `32`):
|
||||
Time step embedding dimension
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
time_embed_dim: int = 32
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
@@ -49,8 +58,16 @@ class FlaxTimestepEmbedding(nn.Module):
|
||||
|
||||
|
||||
class FlaxTimesteps(nn.Module):
|
||||
r"""
|
||||
Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
|
||||
|
||||
Args:
|
||||
dim (`int`, *optional*, defaults to `32`):
|
||||
Time step embedding dimension
|
||||
"""
|
||||
dim: int = 32
|
||||
freq_shift: float = 1
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, timesteps):
|
||||
return get_sinusoidal_embeddings(timesteps, self.dim)
|
||||
return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)
|
||||
|
||||
@@ -149,7 +149,6 @@ class FirUpsample2D(nn.Module):
|
||||
|
||||
stride = (factor, factor)
|
||||
# Determine data dimensions.
|
||||
stride = [1, 1, factor, factor]
|
||||
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
|
||||
output_padding = (
|
||||
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
|
||||
@@ -161,7 +160,7 @@ class FirUpsample2D(nn.Module):
|
||||
|
||||
# Transpose weights.
|
||||
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
||||
weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
|
||||
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
||||
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
||||
|
||||
x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
|
||||
|
||||
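The switch from `weight[..., ::-1, ::-1]` to `torch.flip(weight, dims=[3, 4])` is needed because PyTorch tensors do not support negative-step slicing; `torch.flip` is the supported way to reverse the two spatial axes. A minimal equivalence check against NumPy, which does allow the negative step:

import numpy as np
import torch

w = torch.arange(24.0).reshape(1, 1, 1, 4, 6)
flipped = torch.flip(w, dims=[3, 4])
assert np.allclose(flipped.numpy(), w.numpy()[..., ::-1, ::-1])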
@@ -3,12 +3,21 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
|
||||
from .unet_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
UpBlock2D,
|
||||
get_down_block,
|
||||
get_up_block,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -54,6 +63,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -188,6 +199,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
if hasattr(block, "attentions") and block.attentions is not None:
|
||||
block.set_attention_slice(slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -234,7 +249,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -19,7 +19,7 @@ from .unet_blocks_flax import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@flax.struct.dataclass
|
||||
class FlaxUNet2DConditionOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
@@ -39,10 +39,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the models (such as downloading or saving, etc.)
|
||||
|
||||
Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
||||
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Finally, this model supports inherent JAX features such as:
|
||||
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
||||
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
||||
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
||||
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optional*): The size of the input sample.
|
||||
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
||||
sample_size (`int`, *optional*):
|
||||
The size of the input sample.
|
||||
in_channels (`int`, *optional*, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4):
|
||||
The number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
|
||||
"FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
|
||||
@@ -51,10 +64,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
"FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features.
|
||||
dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 768):
|
||||
The dimension of the cross attention features.
|
||||
dropout (`float`, *optional*, defaults to 0):
|
||||
Dropout probability for down, up and bottleneck blocks.
|
||||
"""
|
||||
|
||||
sample_size: int = 32
|
||||
@@ -73,10 +90,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
cross_attention_dim: int = 1280
|
||||
dropout: float = 0.0
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
freq_shift: int = 0
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
timesteps = jnp.ones((1,), dtype=jnp.int32)
|
||||
encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
|
||||
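With this change `init_weights` builds its dummy sample in the same NCHW layout the PyTorch `UNet2DConditionModel` expects, and `__call__` transposes to NHWC internally (see the `jnp.transpose` added below). A hypothetical initialization sketch; the import path, call signature and shapes are assumptions based on the defaults shown in this hunk:

import jax
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel  # assumed import path

unet = FlaxUNet2DConditionModel(sample_size=32, in_channels=4)
params = unet.init_weights(jax.random.PRNGKey(0))

sample = jnp.zeros((1, 4, 32, 32), dtype=jnp.float32)        # NCHW at the boundary
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 1280), dtype=jnp.float32)
out = unet.apply({"params": params}, sample, timesteps, encoder_hidden_states)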
@@ -100,7 +118,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
# time
|
||||
self.time_proj = FlaxTimesteps(block_out_channels[0])
|
||||
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
|
||||
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
||||
|
||||
# down
|
||||
@@ -214,10 +232,17 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# 1. time
|
||||
if not isinstance(timesteps, jnp.ndarray):
|
||||
timesteps = jnp.array([timesteps], dtype=jnp.int32)
|
||||
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps.astype(dtype=jnp.float32)
|
||||
timesteps = jnp.expand_dims(timesteps, 0)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
t_emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
sample = jnp.transpose(sample, (0, 2, 3, 1))
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
@@ -251,6 +276,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = nn.silu(sample)
|
||||
sample = self.conv_out(sample)
|
||||
sample = jnp.transpose(sample, (0, 3, 1, 2))
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
@@ -527,6 +527,8 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
||||
raise ValueError(
|
||||
@@ -546,8 +548,22 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn), hidden_states, encoder_hidden_states
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
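The `create_custom_forward` closure exists because `torch.utils.checkpoint.checkpoint` takes a callable plus tensor arguments, and wrapping the module keeps the call uniform across resnets and attentions. A standalone sketch of the pattern (not the block's exact code):

import torch
import torch.nn as nn
import torch.utils.checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward

x = torch.randn(4, 16, requires_grad=True)
# Activations inside `block` are recomputed in the backward pass instead of being
# stored, trading compute for memory, which is what the UNet blocks opt into when
# gradient_checkpointing is enabled during training.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
y.sum().backward()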
@@ -609,11 +625,24 @@ class DownBlock2D(nn.Module):
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
@@ -1072,6 +1101,8 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
||||
raise ValueError(
|
||||
@@ -1087,15 +1118,36 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
for attn in self.attentions:
|
||||
attn._set_attention_slice(slice_size)
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
res_hidden_states_tuple,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn), hidden_states, encoder_hidden_states
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -1150,6 +1202,8 @@ class UpBlock2D(nn.Module):
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
@@ -1157,7 +1211,17 @@ class UpBlock2D(nn.Module):
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
|
||||
@@ -19,6 +19,26 @@ from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
||||
|
||||
|
||||
class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
r"""
|
||||
Cross Attention 2D Downsizing block - original architecture from Unet transformers:
|
||||
https://arxiv.org/abs/2103.06104
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`):
|
||||
Input channels
|
||||
out_channels (:obj:`int`):
|
||||
Output channels
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of attention blocks layers
|
||||
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of attention heads of each spatial transformer block
|
||||
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
||||
Whether to add downsampling layer before each final output
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
dropout: float = 0.0
|
||||
@@ -55,7 +75,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
self.attentions = attentions
|
||||
|
||||
if self.add_downsample:
|
||||
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
|
||||
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
|
||||
output_states = ()
|
||||
@@ -66,13 +86,30 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.add_downsample:
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
hidden_states = self.downsamplers_0(hidden_states)
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class FlaxDownBlock2D(nn.Module):
|
||||
r"""
|
||||
Flax 2D downsizing block
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`):
|
||||
Input channels
|
||||
out_channels (:obj:`int`):
|
||||
Output channels
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of Resnet blocks
|
||||
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
||||
Whether to add downsampling layer before each final output
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
dropout: float = 0.0
|
||||
@@ -96,7 +133,7 @@ class FlaxDownBlock2D(nn.Module):
|
||||
self.resnets = resnets
|
||||
|
||||
if self.add_downsample:
|
||||
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
|
||||
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, temb, deterministic=True):
|
||||
output_states = ()
|
||||
@@ -106,13 +143,33 @@ class FlaxDownBlock2D(nn.Module):
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.add_downsample:
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
hidden_states = self.downsamplers_0(hidden_states)
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
r"""
|
||||
Cross Attention 2D Upsampling block - original architecture from Unet transformers:
|
||||
https://arxiv.org/abs/2103.06104
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`):
|
||||
Input channels
|
||||
out_channels (:obj:`int`):
|
||||
Output channels
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of attention blocks layers
|
||||
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of attention heads of each spatial transformer block
|
||||
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
|
||||
Whether to add upsampling layer before each final output
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
prev_output_channel: int
|
||||
@@ -151,7 +208,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
self.attentions = attentions
|
||||
|
||||
if self.add_upsample:
|
||||
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
|
||||
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
@@ -164,12 +221,31 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
|
||||
|
||||
if self.add_upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
hidden_states = self.upsamplers_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxUpBlock2D(nn.Module):
|
||||
r"""
|
||||
Flax 2D upsampling block
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`):
|
||||
Input channels
|
||||
out_channels (:obj:`int`):
|
||||
Output channels
|
||||
prev_output_channel (:obj:`int`):
|
||||
Output channels from the previous block
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of Resnet blocks
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsampling layer before each final output
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
prev_output_channel: int
|
||||
@@ -196,7 +272,7 @@ class FlaxUpBlock2D(nn.Module):
|
||||
self.resnets = resnets
|
||||
|
||||
if self.add_upsample:
|
||||
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
|
||||
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
|
||||
for resnet in self.resnets:
|
||||
@@ -208,12 +284,27 @@ class FlaxUpBlock2D(nn.Module):
|
||||
hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
|
||||
|
||||
if self.add_upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
hidden_states = self.upsamplers_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
r"""
|
||||
Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`):
|
||||
Input channels
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of attention blocks layers
|
||||
attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of attention heads of each spatial transformer block
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
in_channels: int
|
||||
dropout: float = 0.0
|
||||
num_layers: int = 1
|
||||
|
||||
812  src/diffusers/models/vae_flax.py  Normal file
@@ -0,0 +1,812 @@
|
||||
# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
|
||||
from ..configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ..modeling_flax_utils import FlaxModelMixin
|
||||
from ..utils import BaseOutput
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxDecoderOutput(BaseOutput):
|
||||
"""
|
||||
Output of decoding method.
|
||||
|
||||
Args:
|
||||
sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
|
||||
Decoded output sample of the model. Output of the last layer of the model.
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
|
||||
sample: jnp.ndarray
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxAutoencoderKLOutput(BaseOutput):
|
||||
"""
|
||||
Output of AutoencoderKL encoding method.
|
||||
|
||||
Args:
|
||||
latent_dist (`FlaxDiagonalGaussianDistribution`):
|
||||
Encoded outputs of `Encoder` represented as the mean and logvar of `FlaxDiagonalGaussianDistribution`.
|
||||
`FlaxDiagonalGaussianDistribution` allows for sampling latents from the distribution.
|
||||
"""
|
||||
|
||||
latent_dist: "FlaxDiagonalGaussianDistribution"
|
||||
|
||||
|
||||
class FlaxUpsample2D(nn.Module):
|
||||
"""
|
||||
Flax implementation of 2D Upsample layer
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Input channels
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
|
||||
in_channels: int
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.conv = nn.Conv(
|
||||
self.in_channels,
|
||||
kernel_size=(3, 3),
|
||||
strides=(1, 1),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
batch, height, width, channels = hidden_states.shape
|
||||
hidden_states = jax.image.resize(
|
||||
hidden_states,
|
||||
shape=(batch, height * 2, width * 2, channels),
|
||||
method="nearest",
|
||||
)
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxDownsample2D(nn.Module):
|
||||
"""
|
||||
Flax implementation of 2D Downsample layer
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Input channels
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
|
||||
in_channels: int
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.conv = nn.Conv(
|
||||
self.in_channels,
|
||||
kernel_size=(3, 3),
|
||||
strides=(2, 2),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
|
||||
hidden_states = jnp.pad(hidden_states, pad_width=pad)
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
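FlaxDownsample2D pads one extra row and column on the bottom/right and then runs a stride-2 VALID convolution, which mirrors the asymmetric pad-then-conv used by the PyTorch downsample layer when its conv padding is 0. A quick shape check in NHWC:

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 64, 64, 32))                      # NHWC
x = jnp.pad(x, ((0, 0), (0, 1), (0, 1), (0, 0)))   # pad bottom/right only
conv = nn.Conv(32, kernel_size=(3, 3), strides=(2, 2), padding="VALID")
params = conv.init(jax.random.PRNGKey(0), x)
print(conv.apply(params, x).shape)                 # (1, 32, 32, 32)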
class FlaxResnetBlock2D(nn.Module):
|
||||
"""
|
||||
Flax implementation of 2D Resnet Block.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Input channels
|
||||
out_channels (`int`):
|
||||
Output channels
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
|
||||
Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
|
||||
in_channels: int
|
||||
out_channels: int = None
|
||||
dropout: float = 0.0
|
||||
use_nin_shortcut: bool = None
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
out_channels = self.in_channels if self.out_channels is None else self.out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.conv1 = nn.Conv(
|
||||
out_channels,
|
||||
kernel_size=(3, 3),
|
||||
strides=(1, 1),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.dropout_layer = nn.Dropout(self.dropout)
|
||||
self.conv2 = nn.Conv(
|
||||
out_channels,
|
||||
kernel_size=(3, 3),
|
||||
strides=(1, 1),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
|
||||
|
||||
self.conv_shortcut = None
|
||||
if use_nin_shortcut:
|
||||
self.conv_shortcut = nn.Conv(
|
||||
out_channels,
|
||||
kernel_size=(1, 1),
|
||||
strides=(1, 1),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True):
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = nn.swish(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = nn.swish(hidden_states)
|
||||
hidden_states = self.dropout_layer(hidden_states, deterministic)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
residual = self.conv_shortcut(residual)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class FlaxAttentionBlock(nn.Module):
|
||||
r"""
|
||||
Flax Convolutional based multi-head attention block for diffusion-based VAE.
|
||||
|
||||
Parameters:
|
||||
channels (:obj:`int`):
|
||||
Input channels
|
||||
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
|
||||
Number of attention heads
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
|
||||
"""
|
||||
channels: int
|
||||
num_head_channels: int = None
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
|
||||
|
||||
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
|
||||
|
||||
self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.query, self.key, self.value = dense(), dense(), dense()
|
||||
self.proj_attn = dense()
|
||||
|
||||
def transpose_for_scores(self, projection):
|
||||
new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
|
||||
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D)
|
||||
new_projection = projection.reshape(new_projection_shape)
|
||||
# (B, T, H, D) -> (B, H, T, D)
|
||||
new_projection = jnp.transpose(new_projection, (0, 2, 1, 3))
|
||||
return new_projection
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
residual = hidden_states
|
||||
batch, height, width, channels = hidden_states.shape
|
||||
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape((batch, height * width, channels))
|
||||
|
||||
query = self.query(hidden_states)
|
||||
key = self.key(hidden_states)
|
||||
value = self.value(hidden_states)
|
||||
|
||||
# transpose
|
||||
query = self.transpose_for_scores(query)
|
||||
key = self.transpose_for_scores(key)
|
||||
value = self.transpose_for_scores(value)
|
||||
|
||||
# compute attentions
|
||||
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
|
||||
attn_weights = jnp.einsum("...qc,...kc->...qk", query * scale, key * scale)
|
||||
attn_weights = nn.softmax(attn_weights, axis=-1)
|
||||
|
||||
# attend to values
|
||||
hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
|
||||
|
||||
hidden_states = jnp.transpose(hidden_states, (0, 2, 1, 3))
|
||||
new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.reshape(new_hidden_states_shape)
|
||||
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
hidden_states = hidden_states.reshape((batch, height, width, channels))
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxDownEncoderBlock2D(nn.Module):
|
||||
r"""
|
||||
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`):
|
||||
Input channels
|
||||
out_channels (:obj:`int`):
|
||||
Output channels
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of Resnet layer block
|
||||
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
|
||||
Whether to add downsample layer
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
dropout: float = 0.0
|
||||
num_layers: int = 1
|
||||
add_downsample: bool = True
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
resnets = []
|
||||
for i in range(self.num_layers):
|
||||
in_channels = self.in_channels if i == 0 else self.out_channels
|
||||
|
||||
res_block = FlaxResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=self.out_channels,
|
||||
dropout=self.dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
resnets.append(res_block)
|
||||
self.resnets = resnets
|
||||
|
||||
if self.add_downsample:
|
||||
self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, deterministic=deterministic)
|
||||
|
||||
if self.add_downsample:
|
||||
hidden_states = self.downsamplers_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxUpEncoderBlock2D(nn.Module):
|
||||
r"""
|
||||
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`):
|
||||
Input channels
|
||||
out_channels (:obj:`int`):
|
||||
Output channels
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of Resnet layer block
|
||||
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsample layer
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
dropout: float = 0.0
|
||||
num_layers: int = 1
|
||||
add_upsample: bool = True
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
resnets = []
|
||||
for i in range(self.num_layers):
|
||||
in_channels = self.in_channels if i == 0 else self.out_channels
|
||||
res_block = FlaxResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=self.out_channels,
|
||||
dropout=self.dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
||||
self.resnets = resnets
|
||||
|
||||
if self.add_upsample:
|
||||
self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, deterministic=deterministic)
|
||||
|
||||
if self.add_upsample:
|
||||
hidden_states = self.upsamplers_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxUNetMidBlock2D(nn.Module):
|
||||
r"""
|
||||
Flax Unet Mid-Block module.
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`):
|
||||
Input channels
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate
|
||||
num_layers (:obj:`int`, *optional*, defaults to 1):
|
||||
Number of Resnet layer block
|
||||
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
|
||||
Number of attention heads for each attention block
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
in_channels: int
|
||||
dropout: float = 0.0
|
||||
num_layers: int = 1
|
||||
attn_num_head_channels: int = 1
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
FlaxResnetBlock2D(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.in_channels,
|
||||
dropout=self.dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
]
|
||||
|
||||
attentions = []
|
||||
|
||||
for _ in range(self.num_layers):
|
||||
attn_block = FlaxAttentionBlock(
|
||||
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
|
||||
)
|
||||
attentions.append(attn_block)
|
||||
|
||||
res_block = FlaxResnetBlock2D(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.in_channels,
|
||||
dropout=self.dropout,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
||||
self.resnets = resnets
|
||||
self.attentions = attentions
|
||||
|
||||
def __call__(self, hidden_states, deterministic=True):
|
||||
hidden_states = self.resnets[0](hidden_states, deterministic=deterministic)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states)
|
||||
hidden_states = resnet(hidden_states, deterministic=deterministic)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxEncoder(nn.Module):
|
||||
r"""
|
||||
Flax Implementation of VAE Encoder.
|
||||
|
||||
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
||||
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Finally, this model supports inherent JAX features such as:
|
||||
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
||||
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
||||
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
||||
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`, *optional*, defaults to 3):
|
||||
Input channels
|
||||
out_channels (:obj:`int`, *optional*, defaults to 3):
|
||||
Output channels
|
||||
down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
|
||||
DownEncoder block type
|
||||
block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
|
||||
Tuple containing the number of output channels for each block
|
||||
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
|
||||
Number of Resnet layer for each block
|
||||
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||
norm num group
|
||||
act_fn (:obj:`str`, *optional*, defaults to `silu`):
|
||||
Activation function
|
||||
double_z (:obj:`bool`, *optional*, defaults to `False`):
|
||||
Whether to double the last output channels
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
||||
block_out_channels: Tuple[int] = (64,)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
act_fn: str = "silu"
|
||||
double_z: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
block_out_channels = self.block_out_channels
|
||||
# in
|
||||
self.conv_in = nn.Conv(
|
||||
block_out_channels[0],
|
||||
kernel_size=(3, 3),
|
||||
strides=(1, 1),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# downsampling
|
||||
down_blocks = []
|
||||
output_channel = block_out_channels[0]
|
||||
for i, _ in enumerate(self.down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = FlaxDownEncoderBlock2D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=self.layers_per_block,
|
||||
add_downsample=not is_final_block,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
down_blocks.append(down_block)
|
||||
self.down_blocks = down_blocks
|
||||
|
||||
# middle
|
||||
self.mid_block = FlaxUNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
|
||||
)
|
||||
|
||||
# end
|
||||
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
|
||||
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.conv_out = nn.Conv(
|
||||
conv_out_channels,
|
||||
kernel_size=(3, 3),
|
||||
strides=(1, 1),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __call__(self, sample, deterministic: bool = True):
|
||||
# in
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# downsampling
|
||||
for block in self.down_blocks:
|
||||
sample = block(sample, deterministic=deterministic)
|
||||
|
||||
# middle
|
||||
sample = self.mid_block(sample, deterministic=deterministic)
|
||||
|
||||
# end
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = nn.swish(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class FlaxDecoder(nn.Module):
|
||||
r"""
|
||||
Flax Implementation of VAE Decoder.
|
||||
|
||||
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
||||
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Finally, this model supports inherent JAX features such as:
|
||||
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
||||
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
||||
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
||||
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`, *optional*, defaults to 3):
|
||||
Input channels
|
||||
out_channels (:obj:`int`, *optional*, defaults to 3):
|
||||
Output channels
|
||||
up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
|
||||
UpDecoder block type
|
||||
block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
|
||||
Tuple containing the number of output channels for each block
|
||||
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
|
||||
Number of Resnet layer for each block
|
||||
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||
norm num group
|
||||
act_fn (:obj:`str`, *optional*, defaults to `silu`):
|
||||
Activation function
|
||||
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
parameters `dtype`
|
||||
"""
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: int = (64,)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
act_fn: str = "silu"
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
block_out_channels = self.block_out_channels
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = nn.Conv(
|
||||
block_out_channels[-1],
|
||||
kernel_size=(3, 3),
|
||||
strides=(1, 1),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid_block = FlaxUNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
|
||||
)
|
||||
|
||||
# upsampling
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
up_blocks = []
|
||||
for i, _ in enumerate(self.up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = FlaxUpEncoderBlock2D(
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=self.layers_per_block + 1,
|
||||
add_upsample=not is_final_block,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
self.up_blocks = up_blocks
|
||||
|
||||
# end
|
||||
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
|
||||
self.conv_out = nn.Conv(
|
||||
self.out_channels,
|
||||
kernel_size=(3, 3),
|
||||
strides=(1, 1),
|
||||
padding=((1, 1), (1, 1)),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def __call__(self, sample, deterministic: bool = True):
|
||||
# z to block_in
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# middle
|
||||
sample = self.mid_block(sample, deterministic=deterministic)
|
||||
|
||||
# upsampling
|
||||
for block in self.up_blocks:
|
||||
sample = block(sample, deterministic=deterministic)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = nn.swish(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class FlaxDiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
# Last axis to account for channels-last
|
||||
self.mean, self.logvar = jnp.split(parameters, 2, axis=-1)
|
||||
self.logvar = jnp.clip(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = jnp.exp(0.5 * self.logvar)
|
||||
self.var = jnp.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = jnp.zeros_like(self.mean)
|
||||
|
||||
def sample(self, key):
|
||||
return self.mean + self.std * jax.random.normal(key, self.mean.shape)
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return jnp.array([0.0])
|
||||
|
||||
if other is None:
|
||||
return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=[1, 2, 3])
|
||||
|
||||
return 0.5 * jnp.sum(
|
||||
jnp.square(self.mean - other.mean) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
axis=[1, 2, 3],
|
||||
)
|
||||
|
||||
def nll(self, sample, axis=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return jnp.array([0.0])
|
||||
|
||||
logtwopi = jnp.log(2.0 * jnp.pi)
|
||||
return 0.5 * jnp.sum(logtwopi + self.logvar + jnp.square(sample - self.mean) / self.var, axis=axis)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
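FlaxDiagonalGaussianDistribution splits the `moments` tensor along the last (channel) axis into mean and log-variance and samples with the reparameterization trick. A minimal numeric sketch mirroring the class above:

import jax
import jax.numpy as jnp

moments = jnp.zeros((1, 8, 8, 8))                  # channels-last: 4 mean + 4 logvar
mean, logvar = jnp.split(moments, 2, axis=-1)
std = jnp.exp(0.5 * jnp.clip(logvar, -30.0, 20.0))

latents = mean + std * jax.random.normal(jax.random.PRNGKey(0), mean.shape)
kl = 0.5 * jnp.sum(mean**2 + jnp.exp(logvar) - 1.0 - logvar, axis=[1, 2, 3])
print(latents.shape, kl)                           # (1, 8, 8, 4), kl == [0.] here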
@flax_register_to_config
|
||||
class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
r"""
|
||||
Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational
|
||||
Bayes by Diederik P. Kingma and Max Welling.
|
||||
|
||||
This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
|
||||
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
|
||||
general usage and behavior.
|
||||
|
||||
Finally, this model supports inherent JAX features such as:
|
||||
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
||||
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
||||
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
||||
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`, *optional*, defaults to 3):
|
||||
Input channels
|
||||
out_channels (:obj:`int`, *optional*, defaults to 3):
|
||||
Output channels
|
||||
down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`):
|
||||
DownEncoder block type
|
||||
up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`):
|
||||
UpDecoder block type
|
||||
block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`):
|
||||
Tuple containing the number of output channels for each block
|
||||
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
|
||||
Number of Resnet layer for each block
|
||||
act_fn (:obj:`str`, *optional*, defaults to `silu`):
|
||||
Activation function
|
||||
latent_channels (:obj:`int`, *optional*, defaults to `4`):
|
||||
Latent space channels
|
||||
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
|
||||
Norm num group
|
||||
sample_size (:obj:`int`, *optional*, defaults to `32`):
|
||||
Sample input size
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
parameters `dtype`
|
||||
"""
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: Tuple[int] = (64,)
|
||||
layers_per_block: int = 1
|
||||
act_fn: str = "silu"
|
||||
latent_channels: int = 4
|
||||
norm_num_groups: int = 32
|
||||
sample_size: int = 32
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
self.encoder = FlaxEncoder(
|
||||
in_channels=self.config.in_channels,
|
||||
out_channels=self.config.latent_channels,
|
||||
down_block_types=self.config.down_block_types,
|
||||
block_out_channels=self.config.block_out_channels,
|
||||
layers_per_block=self.config.layers_per_block,
|
||||
act_fn=self.config.act_fn,
|
||||
norm_num_groups=self.config.norm_num_groups,
|
||||
double_z=True,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.decoder = FlaxDecoder(
|
||||
in_channels=self.config.latent_channels,
|
||||
out_channels=self.config.out_channels,
|
||||
up_block_types=self.config.up_block_types,
|
||||
block_out_channels=self.config.block_out_channels,
|
||||
layers_per_block=self.config.layers_per_block,
|
||||
norm_num_groups=self.config.norm_num_groups,
|
||||
act_fn=self.config.act_fn,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.quant_conv = nn.Conv(
|
||||
2 * self.config.latent_channels,
|
||||
kernel_size=(1, 1),
|
||||
strides=(1, 1),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.post_quant_conv = nn.Conv(
|
||||
self.config.latent_channels,
|
||||
kernel_size=(1, 1),
|
||||
strides=(1, 1),
|
||||
padding="VALID",
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
|
||||
params_rng, dropout_rng, gaussian_rng = jax.random.split(rng, 3)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng, "gaussian": gaussian_rng}
|
||||
|
||||
return self.init(rngs, sample)["params"]
|
||||
|
||||
def encode(self, sample, deterministic: bool = True, return_dict: bool = True):
|
||||
sample = jnp.transpose(sample, (0, 2, 3, 1))
|
||||
|
||||
hidden_states = self.encoder(sample, deterministic=deterministic)
|
||||
moments = self.quant_conv(hidden_states)
|
||||
posterior = FlaxDiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return FlaxAutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def decode(self, latents, deterministic: bool = True, return_dict: bool = True):
|
||||
if latents.shape[-1] != self.config.latent_channels:
|
||||
latents = jnp.transpose(latents, (0, 2, 3, 1))
|
||||
|
||||
hidden_states = self.post_quant_conv(latents)
|
||||
hidden_states = self.decoder(hidden_states, deterministic=deterministic)
|
||||
|
||||
hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
return FlaxDecoderOutput(sample=hidden_states)
|
||||
|
||||
def __call__(self, sample, sample_posterior=False, deterministic: bool = True, return_dict: bool = True):
|
||||
posterior = self.encode(sample, deterministic=deterministic, return_dict=return_dict)
|
||||
if sample_posterior:
|
||||
rng = self.make_rng("gaussian")
|
||||
hidden_states = posterior.latent_dist.sample(rng)
|
||||
else:
|
||||
hidden_states = posterior.latent_dist.mode()
|
||||
|
||||
sample = self.decode(hidden_states, return_dict=return_dict).sample
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return FlaxDecoderOutput(sample=sample)
|
||||
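A hypothetical encode/decode round trip with the new Flax VAE; the import path and the `apply(..., method=...)` call pattern are assumptions, and shapes follow the class defaults above:

import jax
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL  # assumed import path

vae = FlaxAutoencoderKL(sample_size=32)
params = vae.init_weights(jax.random.PRNGKey(0))

images = jnp.zeros((1, 3, 32, 32))                                   # NCHW at the boundary
posterior = vae.apply({"params": params}, images, method=vae.encode).latent_dist
latents = posterior.mode()                                           # or .sample(rng)
decoded = vae.apply({"params": params}, latents, method=vae.decode).sample
print(decoded.shape)                                                 # (1, 3, 32, 32) with the defaults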
@@ -24,16 +24,13 @@ import numpy as np

from huggingface_hub import hf_hub_download

from .utils import is_onnx_available, logging
from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging


if is_onnx_available():
    import onnxruntime as ort


ONNX_WEIGHTS_NAME = "model.onnx"


logger = logging.get_logger(__name__)


@@ -49,7 +46,7 @@ class OnnxRuntimeModel:
        return self.model.run(None, inputs)

    @staticmethod
    def load_model(path: Union[str, Path], provider=None):
    def load_model(path: Union[str, Path], provider=None, sess_options=None):
        """
        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

@@ -63,7 +60,7 @@ class OnnxRuntimeModel:
            logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            provider = "CPUExecutionProvider"

        return ort.InferenceSession(path, providers=[provider])
        return ort.InferenceSession(path, providers=[provider], sess_options=sess_options)

    def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
        """
@@ -117,6 +114,7 @@ class OnnxRuntimeModel:
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["ort.SessionOptions"] = None,
        **kwargs,
    ):
        """
@@ -146,7 +144,9 @@ class OnnxRuntimeModel:
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
        # load model from local directory
        if os.path.isdir(model_id):
            model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
            model = OnnxRuntimeModel.load_model(
                os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
            )
            kwargs["model_save_dir"] = Path(model_id)
        # load model from hub
        else:
@@ -161,7 +161,7 @@ class OnnxRuntimeModel:
            )
            kwargs["model_save_dir"] = Path(model_cache_path).parent
            kwargs["latest_model_name"] = Path(model_cache_path).name
            model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
            model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
        return cls(model=model, **kwargs)

    @classmethod

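To illustrate what the new `sess_options` plumbing enables, here is a hedged sketch; the thread-count and optimization settings are just examples of standard `onnxruntime.SessionOptions` attributes, and the model directory path is illustrative:

```python
import onnxruntime as ort
from diffusers import OnnxRuntimeModel  # class referenced as diffusers.OnnxRuntimeModel in pipeline_utils below

# Configure the InferenceSession before it is created.
options = ort.SessionOptions()
options.intra_op_num_threads = 4
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# The options are forwarded to ort.InferenceSession through the new `sess_options` argument.
model = OnnxRuntimeModel.from_pretrained(
    "path/to/onnx_model_dir",  # local directory containing model.onnx (illustrative path)
    provider="CPUExecutionProvider",
    sess_options=options,
)
```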
474
src/diffusers/pipeline_flax_utils.py
Normal file
@@ -0,0 +1,474 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import inspect
import os
from typing import Dict, List, Optional, Union

import numpy as np

import flax
import PIL
from flax.core.frozen_dict import FrozenDict
from huggingface_hub import snapshot_download
from PIL import Image
from tqdm.auto import tqdm

from .configuration_utils import ConfigMixin
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME, SchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging


if is_transformers_available():
    from transformers import FlaxPreTrainedModel

INDEX_FILE = "diffusion_flax_model.bin"


logger = logging.get_logger(__name__)


LOADABLE_CLASSES = {
    "diffusers": {
        "FlaxModelMixin": ["save_pretrained", "from_pretrained"],
        "SchedulerMixin": ["save_config", "from_config"],
        "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
    },
}

ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


class DummyChecker:
    def __init__(self):
        self.dummy = True


def import_flax_or_no_model(module, class_name):
    try:
        # 1. First make sure that if a Flax object is present, import this one
        class_obj = getattr(module, "Flax" + class_name)
    except AttributeError:
        # 2. If this doesn't work, it's not a model and we don't append "Flax"
        class_obj = getattr(module, class_name)
    except AttributeError:
        raise ValueError(f"Neither Flax{class_name} nor {class_name} exist in {module}")

    return class_obj


@flax.struct.dataclass
class FlaxImagePipelineOutput(BaseOutput):
    """
    Output class for image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]


class FlaxDiffusionPipeline(ConfigMixin):
    r"""
    Base class for all models.

    [`FlaxDiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion
    pipelines and handles methods for loading, downloading and saving models as well as a few methods common to all
    pipelines to:

        - enabling/disabling the progress bar for the denoising iteration

    Class attributes:

        - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
          components of the diffusion pipeline.
    """
    config_name = "model_index.json"

    def register_modules(self, **kwargs):
        # import it here to avoid circular import
        from diffusers import pipelines

        for name, module in kwargs.items():
            # retrieve library
            library = module.__module__.split(".")[0]

            # check if the module is a pipeline module
            pipeline_dir = module.__module__.split(".")[-2]
            path = module.__module__.split(".")
            is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

            # if library is not in LOADABLE_CLASSES, then it is a custom module.
            # Or if it's a pipeline module, then the module is inside the pipeline
            # folder so we set the library to module name.
            if library not in LOADABLE_CLASSES or is_pipeline_module:
                library = pipeline_dir

            # retrieve class_name
            class_name = module.__class__.__name__

            register_dict = {name: (library, class_name)}

            # save model index config
            self.register_to_config(**register_dict)

            # set models
            setattr(self, name, module)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]):
        # TODO: handle inference_state
        """
        Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
        a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
        method. The pipeline can easily be re-loaded using the [`~FlaxDiffusionPipeline.from_pretrained`] class
        method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
        """
        self.save_config(save_directory)

        model_index_dict = dict(self.config)
        model_index_dict.pop("_class_name")
        model_index_dict.pop("_diffusers_version")
        model_index_dict.pop("_module", None)

        for pipeline_component_name in model_index_dict.keys():
            sub_model = getattr(self, pipeline_component_name)
            model_cls = sub_model.__class__

            save_method_name = None
            # search for the model's base class in LOADABLE_CLASSES
            for library_name, library_classes in LOADABLE_CLASSES.items():
                library = importlib.import_module(library_name)
                for base_class, save_load_methods in library_classes.items():
                    class_candidate = getattr(library, base_class)
                    if issubclass(model_cls, class_candidate):
                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
                        save_method_name = save_load_methods[0]
                        break
                if save_method_name is not None:
                    break

            # TODO(Patrick, Suraj): to delete after
            if isinstance(sub_model, DummyChecker):
                continue

            save_method = getattr(sub_model, save_method_name)
            expects_params = "params" in set(inspect.signature(save_method).parameters.keys())

            if expects_params:
                save_method(
                    os.path.join(save_directory, pipeline_component_name), params=params[pipeline_component_name]
                )
            else:
                save_method(os.path.join(save_directory, pipeline_component_name))

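A brief usage sketch for the save path above; the checkpoint id, `from_pt=True` conversion and target directory are assumptions, and the key point is that `params` (returned by `from_pretrained` as defined later in this file) is keyed by component name, exactly what the `params[pipeline_component_name]` lookup expects:

```python
from diffusers.pipelines import FlaxStableDiffusionPipeline  # any FlaxDiffusionPipeline subclass

# `from_pretrained` returns both the pipeline object and its parameter tree.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", from_pt=True, use_auth_token=True
)

# Save config plus one sub-folder per component (vae, unet, scheduler, ...),
# passing each component its slice of the parameter tree.
pipeline.save_pretrained("./my_flax_pipeline", params=params)
```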
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.

        The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
                      https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
                      `CompVis/ldm-text2im-large-256`.
                    - A path to a *directory* containing pipeline weights saved using
                      [`~FlaxDiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
            dtype (`str` or `jnp.dtype`, *optional*):
                Override the default `jnp.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
                will be automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `huggingface-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information. specify the folder name here.

            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
                specific pipeline class. The overwritten components are then directly passed to the pipelines
                `__init__` method. See example below for more information.

        <Tip>

        Passing `use_auth_token=True` is required when you want to use a private model, *e.g.*
        `"CompVis/stable-diffusion-v1-4"`

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
        this method in a firewalled environment.

        </Tip>

        Examples:

        ```py
        >>> from diffusers import FlaxDiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        >>> # Download pipeline that requires an authorization token
        >>> # For more information on access tokens, please refer to this section
        >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
        >>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)

        >>> # Download pipeline, but overwrite scheduler
        >>> from diffusers import LMSDiscreteScheduler

        >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
        >>> pipeline = FlaxDiffusionPipeline.from_pretrained(
        ...     "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
        ... )
        ```
        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_pt = kwargs.pop("from_pt", False)
        dtype = kwargs.pop("dtype", None)

        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
        if not os.path.isdir(pretrained_model_name_or_path):
            config_dict = cls.get_config_dict(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
            )
            # make sure we only download sub-folders and `diffusers` filenames
            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
            allow_patterns = [os.path.join(k, "*") for k in folder_names]
            allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]

            # download all allow_patterns
            cached_folder = snapshot_download(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                allow_patterns=allow_patterns,
            )
        else:
            cached_folder = pretrained_model_name_or_path

        config_dict = cls.get_config_dict(cached_folder)

        # 2. Load the pipeline class, if using custom module then load it from the hub
        # if we load from explicit class, let's use it
        if cls != FlaxDiffusionPipeline:
            pipeline_class = cls
        else:
            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

        # some modules can be passed directly to the init
        # in this case they are already instantiated in `kwargs`
        # extract them here
        expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

        init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        init_kwargs = {}

        # inference_params
        params = {}

        # import it here to avoid circular import
        from diffusers import pipelines

        # 3. Load each module in the pipeline
        for name, (library_name, class_name) in init_dict.items():
            # TODO(Patrick, Suraj) - delete later
            if class_name == "DummyChecker":
                library_name = "stable_diffusion"
                class_name = "FlaxStableDiffusionSafetyChecker"

            is_pipeline_module = hasattr(pipelines, library_name)
            loaded_sub_model = None

            # if the model is in a pipeline module, then we load it from the pipeline
            if name in passed_class_obj:
                # 1. check that passed_class_obj has correct parent class
                if not is_pipeline_module:
                    library = importlib.import_module(library_name)
                    class_obj = getattr(library, class_name)
                    importable_classes = LOADABLE_CLASSES[library_name]
                    class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

                    expected_class_obj = None
                    for class_name, class_candidate in class_candidates.items():
                        if issubclass(class_obj, class_candidate):
                            expected_class_obj = class_candidate

                    if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
                        raise ValueError(
                            f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
                            f" {expected_class_obj}"
                        )
                else:
                    logger.warn(
                        f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
                        " has the correct type"
                    )

                # set passed class object
                loaded_sub_model = passed_class_obj[name]
            elif is_pipeline_module:
                pipeline_module = getattr(pipelines, library_name)
                if from_pt:
                    class_obj = import_flax_or_no_model(pipeline_module, class_name)
                else:
                    class_obj = getattr(pipeline_module, class_name)

                importable_classes = ALL_IMPORTABLE_CLASSES
                class_candidates = {c: class_obj for c in importable_classes.keys()}
            else:
                # else we just import it from the library.
                library = importlib.import_module(library_name)
                if from_pt:
                    class_obj = import_flax_or_no_model(library, class_name)
                else:
                    class_obj = getattr(library, class_name)

                importable_classes = LOADABLE_CLASSES[library_name]
                class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}

            if loaded_sub_model is None:
                load_method_name = None
                for class_name, class_candidate in class_candidates.items():
                    if issubclass(class_obj, class_candidate):
                        load_method_name = importable_classes[class_name][1]

                load_method = getattr(class_obj, load_method_name)

                # check if the module is in a subdirectory
                if os.path.isdir(os.path.join(cached_folder, name)):
                    loadable_folder = os.path.join(cached_folder, name)
                else:
                    loadable_folder = cached_folder

                if issubclass(class_obj, FlaxModelMixin):
                    loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
                    params[name] = loaded_params
                elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
                    # make sure we don't initialize the weights to save time
                    if name == "safety_checker":
                        loaded_sub_model = DummyChecker()
                        loaded_params = {}
                    elif from_pt:
                        # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
                        loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
                        loaded_params = loaded_sub_model.params
                        del loaded_sub_model._params
                    else:
                        loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
                    params[name] = loaded_params
                elif issubclass(class_obj, SchedulerMixin):
                    loaded_sub_model, scheduler_state = load_method(loadable_folder)
                    params[name] = scheduler_state
                else:
                    loaded_sub_model = load_method(loadable_folder)

            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

        model = pipeline_class(**init_kwargs, dtype=dtype)
        return model, params

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    # TODO: make it compatible with jax.lax
    def progress_bar(self, iterable):
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        elif not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

        return tqdm(iterable, **self._progress_bar_config)

    def set_progress_bar_config(self, **kwargs):
        self._progress_bar_config = kwargs
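As a small usage note for the two helpers above, the stored configuration is forwarded verbatim to `tqdm`, so any `tqdm` keyword can be set once on the pipeline (a sketch, assuming `pipeline` is any `FlaxDiffusionPipeline` instance):

```python
# Silence the denoising progress bar entirely; `disable` is a plain tqdm keyword
# stored by set_progress_bar_config and unpacked again in progress_bar.
pipeline.set_progress_bar_config(disable=True)

# Or customise it instead (also standard tqdm keywords):
# pipeline.set_progress_bar_config(desc="denoising", leave=False)
```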
@@ -30,10 +30,8 @@ from PIL import Image
from tqdm.auto import tqdm

from .configuration_utils import ConfigMixin
from .modeling_utils import WEIGHTS_NAME
from .onnx_utils import ONNX_WEIGHTS_NAME
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, logging
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging


INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -237,8 +235,8 @@ class DiffusionPipeline(ConfigMixin):

            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
                specific pipeline class. The overritten components are then directly passed to the pipelines `__init__`
                method. See example below for more information.
                specific pipeline class. The overwritten components are then directly passed to the pipelines
                `__init__` method. See example below for more information.

        <Tip>

@@ -284,6 +282,7 @@ class DiffusionPipeline(ConfigMixin):
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        provider = kwargs.pop("provider", None)
        sess_options = kwargs.pop("sess_options", None)

        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
@@ -341,6 +340,10 @@ class DiffusionPipeline(ConfigMixin):

        # 3. Load each module in the pipeline
        for name, (library_name, class_name) in init_dict.items():
            # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
            if class_name.startswith("Flax"):
                class_name = class_name[4:]

            is_pipeline_module = hasattr(pipelines, library_name)
            loaded_sub_model = None

@@ -396,6 +399,7 @@ class DiffusionPipeline(ConfigMixin):
                loading_kwargs["torch_dtype"] = torch_dtype
            if issubclass(class_obj, diffusers.OnnxRuntimeModel):
                loading_kwargs["provider"] = provider
                loading_kwargs["sess_options"] = sess_options

            # check if the module is in a subdirectory
            if os.path.isdir(os.path.join(cached_folder, name)):

@@ -73,7 +73,7 @@ not be used for training. If you want to store the gradients during the forward
We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire
all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.

- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file iteslf, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline.
- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](https://github.com/huggingface/diffusers/blob/5cbed8e0d157f65d3ddc2420dfd09f2df630e978/src/diffusers/pipeline_utils.py#L56) or be directly attached to the model and scheduler components of the pipeline.
- **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and
  use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most
  logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method.

@@ -1,4 +1,4 @@
from ..utils import is_onnx_available, is_transformers_available
from ..utils import is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LDMPipeline
@@ -7,7 +7,7 @@ from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline


if is_transformers_available():
if is_torch_available() and is_transformers_available():
    from .latent_diffusion import LDMTextToImagePipeline
    from .stable_diffusion import (
        StableDiffusionImg2ImgPipeline,
@@ -17,3 +17,6 @@ if is_transformers_available():

if is_transformers_available() and is_onnx_available():
    from .stable_diffusion import StableDiffusionOnnxPipeline

if is_transformers_available() and is_flax_available():
    from .stable_diffusion import FlaxStableDiffusionPipeline

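A hedged sketch of how downstream code might mirror these new availability guards; `is_flax_available` and `is_transformers_available` come from `diffusers.utils` as imported above, and the pipeline import path follows this `__init__`:

```python
from diffusers.utils import is_flax_available, is_transformers_available

if is_flax_available() and is_transformers_available():
    # Only importable when JAX/Flax and transformers are installed.
    from diffusers.pipelines import FlaxStableDiffusionPipeline
else:
    FlaxStableDiffusionPipeline = None  # degrade gracefully on torch-only installs
```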
@@ -27,7 +27,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
        vqvae ([`VQModel`]):
            Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
        bert ([`LDMBertModel`]):
            Text-encoder model based on [BERT](ttps://huggingface.co/docs/transformers/model_doc/bert) architecture.
            Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
        tokenizer (`transformers.BertTokenizer`):
            Tokenizer of class
            [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
@@ -397,7 +397,7 @@ class LDMBertAttention(nn.Module):
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned aross GPUs when using tensor-parallelism.
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)

        attn_output = self.out_proj(attn_output)
@@ -478,7 +478,7 @@ class LDMBertEncoderLayer(nn.Module):
class LDMBertPreTrainedModel(PreTrainedModel):
    config_class = LDMBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

    def _init_weights(self, module):

@@ -67,7 +67,7 @@ pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    image = pipe(prompt).sample[0]
    image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")
```
@@ -89,7 +89,7 @@ pipe = StableDiffusionPipeline.from_pretrained(

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    image = pipe(prompt).sample[0]
    image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")
```
@@ -115,7 +115,7 @@ pipe = StableDiffusionPipeline.from_pretrained(

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    image = pipe(prompt).sample[0]
    image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")
```

@@ -6,7 +6,7 @@ import numpy as np
import PIL
from PIL import Image

from ...utils import BaseOutput, is_onnx_available, is_transformers_available
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available


@dataclass
@@ -35,3 +35,26 @@ if is_transformers_available():

if is_transformers_available() and is_onnx_available():
    from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline

if is_transformers_available() and is_flax_available():
    import flax

    @flax.struct.dataclass
    class FlaxStableDiffusionPipelineOutput(BaseOutput):
        """
        Output class for Stable Diffusion pipelines.

        Args:
            images (`List[PIL.Image.Image]` or `np.ndarray`)
                List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
                num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
            nsfw_content_detected (`List[bool]`)
                List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
                (nsfw) content.
        """

        images: Union[List[PIL.Image.Image], np.ndarray]
        nsfw_content_detected: List[bool]

    from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
    from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

@@ -0,0 +1,217 @@
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker


class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`FlaxCLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        scheduler = scheduler.set_format("np")
        self.dtype = dtype

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def prepare_inputs(self, prompt: Union[str, List[str]]):
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids

    def __call__(
        self,
        prompt_ids: jnp.array,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.random.PRNGKey,
        num_inference_steps: Optional[int] = 50,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        guidance_scale: Optional[float] = 7.5,
        latents: Optional[jnp.array] = None,
        return_dict: bool = True,
        debug: bool = False,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`jnp.array`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_embeddings = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
        context = jnp.concatenate([uncond_embeddings, text_embeddings])

        latents_shape = (
            batch_size,
            self.unet.in_channels,
            self.unet.sample_size,
            self.unet.sample_size,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)

        if debug:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        # TODO: check when flax vae gets merged into main
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)

        # image = jnp.asarray(image).transpose(0, 2, 3, 1)
        # run safety checker
        # TODO: check when flax safety checker gets merged into main
        # safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
        # image, has_nsfw_concept = self.safety_checker(
        #     images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
        # )
        has_nsfw_concept = False

        if not return_dict:
            return (image, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
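For orientation, a hedged end-to-end sketch of driving the new pipeline on a single device. The checkpoint id and `from_pt=True` conversion are assumptions (Flax weights may need to be converted from the PyTorch ones at this point), and the safety checker is still the dummy placeholder noted in the TODOs above; multi-device use would additionally require replicating `params` and sharding `prompt_ids` before `jax.pmap`:

```python
import jax
import numpy as np
from diffusers.pipelines import FlaxStableDiffusionPipeline

# from_pretrained returns both the pipeline object and its parameter tree.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", from_pt=True, use_auth_token=True
)

prompt_ids = pipeline.prepare_inputs("a photo of an astronaut riding a horse on mars")
rng = jax.random.PRNGKey(0)

# Single-device call; passing debug=True would run the Python loop instead of jax.lax.fori_loop.
output = pipeline(prompt_ids, params, rng, num_inference_steps=50)

# output.images is an NHWC array in [0, 1]; convert it to PIL images for saving.
images = pipeline.numpy_to_pil(np.asarray(output.images))
images[0].save("astronaut_rides_horse.png")
```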
@@ -36,7 +36,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
            A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offsensive or harmful.
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -58,7 +58,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            warnings.warn(
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
@@ -278,8 +278,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

@@ -48,7 +48,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
            A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offsensive or harmful.
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -70,7 +70,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            warnings.warn(
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
@@ -213,7 +213,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        # get prompt text embeddings
        text_input = self.tokenizer(
@@ -265,8 +265,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                sigma = self.scheduler.sigmas[t_index]
                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
                latent_model_input = latent_model_input.to(self.unet.dtype)
                t = t.to(self.unet.dtype)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -284,14 +282,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents.to(self.vae.dtype)).sample
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

@@ -12,7 +12,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -66,7 +66,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offsensive or harmful.
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
@@ -78,7 +78,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler],
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
@@ -89,7 +89,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            warnings.warn(
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
@@ -241,8 +241,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            timesteps = torch.tensor(
                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
            )
        else:
            timesteps = self.scheduler.timesteps[-init_timestep]
            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -287,8 +292,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        latents = init_latents
        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
            t_index = t_start + i
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[t_index]
                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -299,10 +309,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
                # masking
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor(t_index))
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
                # masking
                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)

            # masking
            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
            latents = (init_latents_proper * mask) + (latents * (1 - mask))

        # scale and decode the image latents with vae
@@ -313,8 +328,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

@@ -48,20 +48,20 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concet_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concet_idx]
                concept_threshold = self.special_care_embeds_weights[concet_idx].item()
                result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["special_scores"][concet_idx] > 0:
                    result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]})
            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["special_scores"][concept_idx] > 0:
                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
                    adjustment = 0.01

            for concet_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concet_idx]
                concept_threshold = self.concept_embeds_weights[concet_idx].item()
                result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["concept_scores"][concet_idx] > 0:
                    result_img["bad_concepts"].append(concet_idx)
            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["concept_scores"][concept_idx] > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)


147
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
Normal file
@@ -0,0 +1,147 @@
import warnings
from typing import Optional, Tuple

import numpy as np

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from transformers import CLIPConfig, FlaxPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule


def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
    norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
    norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
    return jnp.matmul(norm_emb_1, norm_emb_2.T)


class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
    config: CLIPConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
        self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)

        self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim))
        self.special_care_embeds = self.param(
            "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)
        )

        self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
        self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,))

    def __call__(self, clip_input):
        pooled_output = self.vision_model(clip_input)[1]
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
        cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
        return special_cos_dist, cos_dist

def filtered_with_scores(self, special_cos_dist, cos_dist, images):
|
||||
batch_size = special_cos_dist.shape[0]
|
||||
special_cos_dist = np.asarray(special_cos_dist)
|
||||
cos_dist = np.asarray(cos_dist)
|
||||
|
||||
result = []
|
||||
for i in range(batch_size):
|
||||
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
|
||||
|
||||
# increase this value to create a stronger `nsfw` filter
|
||||
# at the cost of increasing the possibility of filtering benign image inputs
|
||||
adjustment = 0.0
|
||||
|
||||
for concept_idx in range(len(special_cos_dist[0])):
|
||||
concept_cos = special_cos_dist[i][concept_idx]
|
||||
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
|
||||
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
|
||||
if result_img["special_scores"][concept_idx] > 0:
|
||||
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
|
||||
adjustment = 0.01
|
||||
|
||||
for concept_idx in range(len(cos_dist[0])):
|
||||
concept_cos = cos_dist[i][concept_idx]
|
||||
concept_threshold = self.concept_embeds_weights[concept_idx].item()
|
||||
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
|
||||
if result_img["concept_scores"][concept_idx] > 0:
|
||||
result_img["bad_concepts"].append(concept_idx)
|
||||
|
||||
result.append(result_img)
|
||||
|
||||
has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
|
||||
|
||||
images_was_copied = False
|
||||
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
|
||||
if has_nsfw_concept:
|
||||
if not images_was_copied:
|
||||
images_was_copied = True
|
||||
images = images.copy()
|
||||
|
||||
images[idx] = np.zeros(images[idx].shape) # black image
|
||||
|
||||
if any(has_nsfw_concepts):
|
||||
warnings.warn(
|
||||
"Potential NSFW content was detected in one or more images. A black image will be returned"
|
||||
" instead. Try again with a different prompt and/or seed."
|
||||
)
|
||||
|
||||
return images, has_nsfw_concepts
|
||||
|
||||
|
||||
class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
|
||||
config_class = CLIPConfig
|
||||
main_input_name = "clip_input"
|
||||
module_class = FlaxStableDiffusionSafetyCheckerModule
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CLIPConfig,
|
||||
input_shape: Optional[Tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
if input_shape is None:
|
||||
input_shape = (1, 224, 224, 3)
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensor
|
||||
clip_input = jax.random.normal(rng, input_shape)
|
||||
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
random_params = self.module.init(rngs, clip_input)["params"]
|
||||
|
||||
return random_params
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_input,
|
||||
params: dict = None,
|
||||
):
|
||||
clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
jnp.array(clip_input, dtype=jnp.float32),
|
||||
rngs={},
|
||||
)
|
||||
|
||||
def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
|
||||
def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
|
||||
return module.filtered_with_scores(special_cos_dist, cos_dist, images)
|
||||
|
||||
return self.module.apply(
|
||||
{"params": params or self.params},
|
||||
special_cos_dist,
|
||||
cos_dist,
|
||||
images,
|
||||
method=_filtered_with_scores,
|
||||
)
|
||||
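A minimal usage sketch for the new Flax safety checker file added above. Only the two-step `__call__` / `filtered_with_scores` flow is taken from the code itself; the checkpoint id, preprocessing, and image shapes below are assumptions for illustration, not part of this commit.

# Hypothetical usage of FlaxStableDiffusionSafetyChecker (checkpoint id and shapes are assumed)
import jax.numpy as jnp
import numpy as np

from diffusers.pipelines.stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker

# assumed checkpoint; PyTorch safety-checker weights converted on the fly via from_pt
safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker", from_pt=True
)

clip_input = jnp.zeros((1, 3, 224, 224))               # preprocessed pixel values, transposed to NHWC internally
images = np.zeros((1, 512, 512, 3), dtype=np.float32)  # decoded images that may be blacked out

# step 1: cosine distances against the special-care and concept embeddings
special_cos_dist, cos_dist = safety_checker(clip_input)

# step 2: host-side thresholding; returns possibly-blacked-out images and per-image flags
images, has_nsfw_concepts = safety_checker.filtered_with_scores(special_cos_dist, cos_dist, images)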
@@ -17,13 +17,33 @@
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDIMSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
pred_original_sample: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
@@ -179,7 +199,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
) -> Union[DDIMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
@@ -192,11 +212,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
eta (`float`): weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`): TODO
|
||||
generator: random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
@@ -222,6 +242,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
@@ -260,7 +281,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
|
||||
@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import flax
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
@@ -60,11 +59,12 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
|
||||
class DDIMSchedulerState:
|
||||
# setable values
|
||||
timesteps: jnp.ndarray
|
||||
alphas_cumprod: jnp.ndarray
|
||||
num_inference_steps: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, num_train_timesteps: int):
|
||||
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
|
||||
def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
|
||||
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -96,9 +96,19 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
set_alpha_to_one (`bool`, default `True`):
|
||||
if alpha for final step is 1 or the final alpha of the "non-previous" one.
|
||||
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
||||
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
otherwise it uses the value of alpha at step 0.
|
||||
steps_offset (`int`, default `0`):
|
||||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||
stable diffusion.
|
||||
"""
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -106,12 +116,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[jnp.ndarray] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = jnp.asarray(trained_betas)
|
||||
if beta_schedule == "linear":
|
||||
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
@@ -124,19 +131,24 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
|
||||
|
||||
# HACK for now - clean up later (PVP)
|
||||
self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
|
||||
|
||||
# At every step in ddim, we are looking into the previous alphas_cumprod
|
||||
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
||||
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
||||
# whether we use the final alpha of the "non-previous" one.
|
||||
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
|
||||
|
||||
self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps)
|
||||
def create_state(self):
|
||||
return DDIMSchedulerState.create(
|
||||
num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
|
||||
)
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep):
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
|
||||
alpha_prod_t = alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
@@ -144,9 +156,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def set_timesteps(
|
||||
self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0
|
||||
) -> DDIMSchedulerState:
|
||||
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
@@ -155,9 +165,9 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
the `FlaxDDIMScheduler` state data class instance.
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
offset (`int`):
|
||||
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
|
||||
"""
|
||||
offset = self.config.steps_offset
|
||||
|
||||
step_ratio = self.config.num_train_timesteps // num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
@@ -172,9 +182,6 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: jnp.ndarray,
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
key: random.KeyArray,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -213,44 +220,35 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# - pred_sample_direction -> "direction pointing to x_t"
|
||||
# - pred_prev_sample -> "x_t-1"
|
||||
|
||||
# TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
|
||||
eta = 0.0
|
||||
|
||||
# 1. get previous step value (=t-1)
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
|
||||
|
||||
alphas_cumprod = state.alphas_cumprod
|
||||
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
alpha_prod_t = alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
|
||||
# 4. Clip "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
|
||||
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# 4. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
variance = self._get_variance(timestep, prev_timestep)
|
||||
variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
|
||||
std_dev_t = eta * variance ** (0.5)
|
||||
|
||||
if use_clipped_model_output:
|
||||
# the model_output is always re-derived from the clipped x_0 in Glide
|
||||
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
||||
|
||||
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
|
||||
|
||||
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
# 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
|
||||
if eta > 0:
|
||||
key = random.split(key, num=1)
|
||||
noise = random.normal(key=key, shape=model_output.shape)
|
||||
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
|
||||
|
||||
prev_sample = prev_sample + variance
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, state)
|
||||
|
||||
@@ -263,9 +261,14 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod[:, None]
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
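A small sketch of the stateless Flax scheduler flow introduced above, where `create_state` replaces the old `self.state` attribute and the timestep offset is now read from the config; it assumes `FlaxDDIMScheduler` is exported at the package root and the argument values are illustrative only.

from diffusers import FlaxDDIMScheduler

scheduler = FlaxDDIMScheduler(
    num_train_timesteps=1000,
    steps_offset=1,          # shift inference timesteps up by one, as used for stable diffusion
    set_alpha_to_one=False,  # final step re-uses the alpha product at step 0 instead of fixing it to 1
)

state = scheduler.create_state()                                # replaces the old self.state attribute
state = scheduler.set_timesteps(state, num_inference_steps=50)  # offset now comes from the config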
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -15,13 +15,33 @@
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDPMSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
pred_original_sample: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
|
||||
@@ -177,7 +197,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
predict_epsilon=True,
|
||||
generator=None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
) -> Union[DDPMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
@@ -190,11 +210,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
predict_epsilon (`bool`):
|
||||
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
|
||||
generator: random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
@@ -242,7 +262,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (pred_prev_sample,)
|
||||
|
||||
return SchedulerOutput(prev_sample=pred_prev_sample)
|
||||
return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2022 UC Berkely Team and The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -266,9 +266,14 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod[..., None]
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
@@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput):
|
||||
denoising loop.
|
||||
derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Derivative of predicted original image sample (x_0).
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
derivative: torch.FloatTensor
|
||||
pred_original_sample: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
@@ -106,7 +110,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
|
||||
self.schedule = [
|
||||
(
|
||||
self.config.sigma_max
|
||||
self.config.sigma_max**2
|
||||
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
|
||||
)
|
||||
for i in self.timesteps
|
||||
@@ -153,7 +157,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigma_hat (`float`): TODO
|
||||
sigma_prev (`float`): TODO
|
||||
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
|
||||
|
||||
KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
|
||||
Returns:
|
||||
@@ -170,7 +174,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (sample_prev, derivative)
|
||||
|
||||
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
|
||||
return KarrasVeOutput(
|
||||
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
|
||||
)
|
||||
|
||||
def step_correct(
|
||||
self,
|
||||
@@ -192,7 +198,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
derivative (`torch.FloatTensor` or `np.ndarray`): TODO
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
|
||||
|
||||
Returns:
|
||||
prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
|
||||
@@ -205,7 +211,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (sample_prev, derivative)
|
||||
|
||||
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
|
||||
return KarrasVeOutput(
|
||||
prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
|
||||
)
|
||||
|
||||
def add_noise(self, original_samples, noise, timesteps):
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -47,7 +47,7 @@ class FlaxKarrasVeOutput(BaseOutput):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
derivative (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Derivate of predicted original image sample (x_0).
|
||||
Derivative of predicted original image sample (x_0).
|
||||
state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
|
||||
"""
|
||||
|
||||
@@ -113,7 +113,7 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps = jnp.arange(0, num_inference_steps)[::-1].copy()
|
||||
schedule = [
|
||||
(
|
||||
self.config.sigma_max
|
||||
self.config.sigma_max**2
|
||||
* (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
|
||||
)
|
||||
for i in timesteps
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -20,7 +21,26 @@ import torch
|
||||
from scipy import integrate
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class LMSDiscreteSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
pred_original_sample: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
@@ -133,7 +153,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
order: int = 4,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
@@ -144,12 +164,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample (`torch.FloatTensor` or `np.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
order: coefficient for multi-step inference.
|
||||
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
||||
return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
[`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
sigma = self.sigmas[timestep]
|
||||
@@ -175,7 +195,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
|
||||
@@ -198,8 +198,11 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: jnp.ndarray,
|
||||
timesteps: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
sigmas = self.match_shape(state.sigmas[timesteps], noise)
|
||||
noisy_samples = original_samples + noise * sigmas
|
||||
sigma = state.sigmas[timesteps].flatten()
|
||||
while len(sigma.shape) < len(noise.shape):
|
||||
sigma = sigma[..., None]
|
||||
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
|
||||
return noisy_samples
|
||||
|
||||
|
||||
@@ -12,9 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
@@ -59,7 +59,6 @@ class PNDMSchedulerState:
|
||||
# setable values
|
||||
_timesteps: jnp.ndarray
|
||||
num_inference_steps: Optional[int] = None
|
||||
_offset: int = 0
|
||||
prk_timesteps: Optional[jnp.ndarray] = None
|
||||
plms_timesteps: Optional[jnp.ndarray] = None
|
||||
timesteps: Optional[jnp.ndarray] = None
|
||||
@@ -104,8 +103,20 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
skip_prk_steps (`bool`):
|
||||
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
|
||||
before plms steps; defaults to `False`.
|
||||
set_alpha_to_one (`bool`, default `False`):
|
||||
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
||||
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
otherwise it uses the value of alpha at step 0.
|
||||
steps_offset (`int`, default `0`):
|
||||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||
stable diffusion.
|
||||
"""
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
return True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
@@ -115,6 +126,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[jnp.ndarray] = None,
|
||||
skip_prk_steps: bool = False,
|
||||
set_alpha_to_one: bool = False,
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = jnp.asarray(trained_betas)
|
||||
@@ -132,16 +145,17 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
|
||||
|
||||
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
|
||||
# For now we only support F-PNDM, i.e. the runge-kutta method
|
||||
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
|
||||
# mainly at formula (9), (12), (13) and the Algorithm 2.
|
||||
self.pndm_order = 4
|
||||
|
||||
self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps)
|
||||
def create_state(self):
|
||||
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
|
||||
|
||||
def set_timesteps(
|
||||
self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0
|
||||
) -> PNDMSchedulerState:
|
||||
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
@@ -150,16 +164,15 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
the `FlaxPNDMScheduler` state data class instance.
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
offset (`int`):
|
||||
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
|
||||
"""
|
||||
offset = self.config.steps_offset
|
||||
|
||||
step_ratio = self.config.num_train_timesteps // num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# rounding to avoid issues when num_inference_step is power of 3
|
||||
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
|
||||
_timesteps = _timesteps + offset
|
||||
_timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round() + offset
|
||||
|
||||
state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps)
|
||||
state = state.replace(num_inference_steps=num_inference_steps, _timesteps=_timesteps)
|
||||
|
||||
if self.config.skip_prk_steps:
|
||||
# for some models like stable diffusion the prk steps can/should be skipped to
|
||||
@@ -254,7 +267,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
|
||||
prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1])
|
||||
prev_timestep = timestep - diff_to_prev
|
||||
timestep = state.prk_timesteps[state.counter // 4 * 4]
|
||||
|
||||
if state.counter % 4 == 0:
|
||||
@@ -274,7 +287,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# cur_sample should not be `None`
|
||||
cur_sample = state.cur_sample if state.cur_sample is not None else sample
|
||||
|
||||
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state)
|
||||
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
|
||||
state = state.replace(counter=state.counter + 1)
|
||||
|
||||
if not return_dict:
|
||||
@@ -320,7 +333,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"for more information."
|
||||
)
|
||||
|
||||
prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0)
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
|
||||
|
||||
if state.counter != 1:
|
||||
state = state.replace(ets=state.ets.append(model_output))
|
||||
@@ -344,7 +357,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
|
||||
)
|
||||
|
||||
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state)
|
||||
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
|
||||
state = state.replace(counter=state.counter + 1)
|
||||
|
||||
if not return_dict:
|
||||
@@ -352,7 +365,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
|
||||
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state):
|
||||
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
|
||||
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
|
||||
# this function computes x_(t−δ) using the formula of (9)
|
||||
# Note that x_t needs to be added to both sides of the equation
|
||||
@@ -365,8 +378,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# sample -> x_t
|
||||
# model_output -> e_θ(x_t, t)
|
||||
# prev_sample -> x_(t−δ)
|
||||
alpha_prod_t = self.alphas_cumprod[timestep + 1 - state._offset]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - state._offset]
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
@@ -395,9 +408,14 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod[..., None]
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
@@ -192,14 +192,17 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
|
||||
# also equation 47 shows the analog from SDE models to ancestral sampling methods
|
||||
drift = drift - diffusion[:, None, None, None] ** 2 * model_output
|
||||
diffusion = diffusion.flatten()
|
||||
while len(diffusion.shape) < len(sample.shape):
|
||||
diffusion = diffusion[:, None]
|
||||
drift = drift - diffusion**2 * model_output
|
||||
|
||||
# equation 6: sample noise for the diffusion term of
|
||||
key = random.split(key, num=1)
|
||||
noise = random.normal(key=key, shape=sample.shape)
|
||||
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
|
||||
# TODO is the variable diffusion the correct scaling term for the noise?
|
||||
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
|
||||
prev_sample = prev_sample_mean + diffusion * noise # add impact of diffusion field g
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, prev_sample_mean, state)
|
||||
@@ -248,8 +251,11 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
step_size = step_size * jnp.ones(sample.shape[0])
|
||||
|
||||
# compute corrected sample: model_output term and noise term
|
||||
prev_sample_mean = sample + step_size[:, None, None, None] * model_output
|
||||
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
|
||||
step_size = step_size.flatten()
|
||||
while len(step_size.shape) < len(sample.shape):
|
||||
step_size = step_size[:, None]
|
||||
prev_sample_mean = sample + step_size * model_output
|
||||
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, state)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import unittest
|
||||
from distutils.util import strtobool
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
@@ -92,3 +94,157 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||
image = PIL.ImageOps.exif_transpose(image)
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
# --- pytest conf functions --- #
|
||||
|
||||
# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
|
||||
pytest_opt_registered = {}
|
||||
|
||||
|
||||
def pytest_addoption_shared(parser):
|
||||
"""
|
||||
This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
|
||||
|
||||
It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
|
||||
option.
|
||||
|
||||
"""
|
||||
option = "--make-reports"
|
||||
if option not in pytest_opt_registered:
|
||||
parser.addoption(
|
||||
option,
|
||||
action="store",
|
||||
default=False,
|
||||
help="generate report files. The value of this option is used as a prefix to report names",
|
||||
)
|
||||
pytest_opt_registered[option] = 1
|
||||
|
||||
|
||||
def pytest_terminal_summary_main(tr, id):
|
||||
"""
|
||||
Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
|
||||
directory. The report files are prefixed with the test suite name.
|
||||
|
||||
This function emulates the --durations and -rA pytest arguments.
|
||||
|
||||
This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
|
||||
there.
|
||||
|
||||
Args:
|
||||
- tr: `terminalreporter` passed from `conftest.py`
|
||||
- id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
|
||||
needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
|
||||
|
||||
NB: this function taps into a private _pytest API and, while unlikely, it could break should
|
||||
pytest do internal changes - also it calls default internal methods of terminalreporter which
|
||||
can be hijacked by various `pytest-` plugins and interfere.
|
||||
|
||||
"""
|
||||
from _pytest.config import create_terminal_writer
|
||||
|
||||
if not len(id):
|
||||
id = "tests"
|
||||
|
||||
config = tr.config
|
||||
orig_writer = config.get_terminal_writer()
|
||||
orig_tbstyle = config.option.tbstyle
|
||||
orig_reportchars = tr.reportchars
|
||||
|
||||
dir = "reports"
|
||||
Path(dir).mkdir(parents=True, exist_ok=True)
|
||||
report_files = {
|
||||
k: f"{dir}/{id}_{k}.txt"
|
||||
for k in [
|
||||
"durations",
|
||||
"errors",
|
||||
"failures_long",
|
||||
"failures_short",
|
||||
"failures_line",
|
||||
"passes",
|
||||
"stats",
|
||||
"summary_short",
|
||||
"warnings",
|
||||
]
|
||||
}
|
||||
|
||||
# custom durations report
|
||||
# note: there is no need to call pytest --durations=XX to get this separate report
|
||||
# adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
|
||||
dlist = []
|
||||
for replist in tr.stats.values():
|
||||
for rep in replist:
|
||||
if hasattr(rep, "duration"):
|
||||
dlist.append(rep)
|
||||
if dlist:
|
||||
dlist.sort(key=lambda x: x.duration, reverse=True)
|
||||
with open(report_files["durations"], "w") as f:
|
||||
durations_min = 0.05 # sec
|
||||
f.write("slowest durations\n")
|
||||
for i, rep in enumerate(dlist):
|
||||
if rep.duration < durations_min:
|
||||
f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
|
||||
break
|
||||
f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
|
||||
|
||||
def summary_failures_short(tr):
|
||||
# expecting that the reports were --tb=long (default) so we chop them off here to the last frame
|
||||
reports = tr.getreports("failed")
|
||||
if not reports:
|
||||
return
|
||||
tr.write_sep("=", "FAILURES SHORT STACK")
|
||||
for rep in reports:
|
||||
msg = tr._getfailureheadline(rep)
|
||||
tr.write_sep("_", msg, red=True, bold=True)
|
||||
# chop off the optional leading extra frames, leaving only the last one
|
||||
longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
|
||||
tr._tw.line(longrepr)
|
||||
# note: not printing out any rep.sections to keep the report short
|
||||
|
||||
# use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
|
||||
# adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
|
||||
# note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
|
||||
# pytest-instafail does that)
|
||||
|
||||
# report failures with line/short/long styles
|
||||
config.option.tbstyle = "auto" # full tb
|
||||
with open(report_files["failures_long"], "w") as f:
|
||||
tr._tw = create_terminal_writer(config, f)
|
||||
tr.summary_failures()
|
||||
|
||||
# config.option.tbstyle = "short" # short tb
|
||||
with open(report_files["failures_short"], "w") as f:
|
||||
tr._tw = create_terminal_writer(config, f)
|
||||
summary_failures_short(tr)
|
||||
|
||||
config.option.tbstyle = "line" # one line per error
|
||||
with open(report_files["failures_line"], "w") as f:
|
||||
tr._tw = create_terminal_writer(config, f)
|
||||
tr.summary_failures()
|
||||
|
||||
with open(report_files["errors"], "w") as f:
|
||||
tr._tw = create_terminal_writer(config, f)
|
||||
tr.summary_errors()
|
||||
|
||||
with open(report_files["warnings"], "w") as f:
|
||||
tr._tw = create_terminal_writer(config, f)
|
||||
tr.summary_warnings() # normal warnings
|
||||
tr.summary_warnings() # final warnings
|
||||
|
||||
tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
|
||||
with open(report_files["passes"], "w") as f:
|
||||
tr._tw = create_terminal_writer(config, f)
|
||||
tr.summary_passes()
|
||||
|
||||
with open(report_files["summary_short"], "w") as f:
|
||||
tr._tw = create_terminal_writer(config, f)
|
||||
tr.short_test_summary()
|
||||
|
||||
with open(report_files["stats"], "w") as f:
|
||||
tr._tw = create_terminal_writer(config, f)
|
||||
tr.summary_stats()
|
||||
|
||||
# restore:
|
||||
tr._tw = orig_writer
|
||||
tr.reportchars = orig_reportchars
|
||||
config.option.tbstyle = orig_tbstyle
|
||||
|
||||
@@ -47,6 +47,9 @@ default_cache_path = os.path.join(hf_cache_home, "diffusers")
|
||||
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
|
||||
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
||||
DIFFUSERS_CACHE = default_cache_path
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
|
||||
11
src/diffusers/utils/dummy_flax_and_transformers_objects.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# This file is autogenerated by the command `make fix-copies`, do not edit.
|
||||
# flake8: noqa
|
||||
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class FlaxStableDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["flax", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax", "transformers"])
|
||||
@@ -11,6 +11,27 @@ class FlaxModelMixin(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxUNet2DConditionModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxAutoencoderKL(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxDDIMScheduler(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
@@ -46,13 +67,6 @@ class FlaxPNDMScheduler(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxUNet2DConditionModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxScoreSdeVeScheduler(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
||||
@@ -266,7 +266,7 @@ def reset_format() -> None:
|
||||
|
||||
def warning_advice(self, *args, **kwargs):
|
||||
"""
|
||||
This method is identical to `logger.warninging()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
|
||||
This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
|
||||
warning will not be printed
|
||||
"""
|
||||
no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
|
||||
|
||||
@@ -59,10 +59,17 @@ class BaseOutput(OrderedDict):
|
||||
if not len(class_fields):
|
||||
raise ValueError(f"{self.__class__.__name__} has no fields.")
|
||||
|
||||
for field in class_fields:
|
||||
v = getattr(self, field.name)
|
||||
if v is not None:
|
||||
self[field.name] = v
|
||||
first_field = getattr(self, class_fields[0].name)
|
||||
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
|
||||
|
||||
if other_fields_are_none and isinstance(first_field, dict):
|
||||
for key, value in first_field.items():
|
||||
self[key] = value
|
||||
else:
|
||||
for field in class_fields:
|
||||
v = getattr(self, field.name)
|
||||
if v is not None:
|
||||
self[field.name] = v
|
||||
|
||||
def __delitem__(self, *args, **kwargs):
|
||||
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
||||
|
||||
44
tests/conftest.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# tests directory-specific settings - this file is run automatically
|
||||
# by pytest before any tests are run
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
from os.path import abspath, dirname, join
|
||||
|
||||
|
||||
# allow having multiple repository checkouts and not needing to remember to rerun
|
||||
# 'pip install -e .[dev]' when switching between checkouts and running tests.
|
||||
git_repo_path = abspath(join(dirname(dirname(__file__)), "src"))
|
||||
sys.path.insert(1, git_repo_path)
|
||||
|
||||
# silence FutureWarning warnings in tests since often we can't act on them until
|
||||
# they become normal warnings - i.e. the tests still need to test the current functionality
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
from diffusers.testing_utils import pytest_addoption_shared
|
||||
|
||||
pytest_addoption_shared(parser)
|
||||
|
||||
|
||||
def pytest_terminal_summary(terminalreporter):
|
||||
from diffusers.testing_utils import pytest_terminal_summary_main
|
||||
|
||||
make_reports = terminalreporter.config.getoption("--make-reports")
|
||||
if make_reports:
|
||||
pytest_terminal_summary_main(terminalreporter, id=make_reports)
|
||||
@@ -246,3 +246,21 @@ class ModelTesterMixin:
|
||||
outputs_tuple = model(**inputs_dict, return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
def test_enable_disable_gradient_checkpointing(self):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
return # Skip test if model does not support gradient checkpointing
|
||||
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
# at init model should have gradient checkpointing disabled
|
||||
model = self.model_class(**init_dict)
|
||||
self.assertFalse(model.is_gradient_checkpointing)
|
||||
|
||||
# check enable works
|
||||
model.enable_gradient_checkpointing()
|
||||
self.assertTrue(model.is_gradient_checkpointing)
|
||||
|
||||
# check disable works
|
||||
model.disable_gradient_checkpointing()
|
||||
self.assertFalse(model.is_gradient_checkpointing)
|
||||
|
||||
@@ -18,8 +18,8 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import UNet2DModel
|
||||
from diffusers.testing_utils import floats_tensor, torch_device
|
||||
from diffusers import UNet2DConditionModel, UNet2DModel
|
||||
from diffusers.testing_utils import floats_tensor, slow, torch_device
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin
|
||||
|
||||
@@ -159,6 +159,82 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
        self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))


class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNet2DConditionModel

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)

        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": (32, 64),
            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
            "cross_attention_dim": 32,
            "attention_head_dim": 8,
            "out_channels": 4,
            "in_channels": 4,
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        out = model(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model.zero_grad()
        out.sum().backward()

        # now we save the output and parameter gradients that we will use for comparison purposes with
        # the non-checkpointed run.
        output_not_checkpointed = out.data.clone()
        grad_not_checkpointed = {}
        for name, param in model.named_parameters():
            grad_not_checkpointed[name] = param.grad.data.clone()

        model.enable_gradient_checkpointing()
        out = model(**inputs_dict).sample
        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
        # we won't calculate the loss and rather backprop on out.sum()
        model.zero_grad()
        out.sum().backward()

        # now we save the output and parameter gradients that we will use for comparison purposes with
        # the non-checkpointed run.
        output_checkpointed = out.data.clone()
        grad_checkpointed = {}
        for name, param in model.named_parameters():
            grad_checkpointed[name] = param.grad.data.clone()

        # compare the output and parameters gradients
        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
        for name in grad_checkpointed:
            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))

    # TODO(Patrick) - Re-add this test after having cleaned up LDM
    # def test_output_pretrained_spatial_transformer(self):
    #     model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
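A self-contained sketch, in plain PyTorch rather than the diffusers test harness above, of the property this new test checks: activation checkpointing only changes when activations are recomputed, so the forward output should match exactly and the parameter gradients should agree up to small numerical noise. The toy module, shapes, and tolerance below are illustrative assumptions, not part of the diff.

    import torch
    from torch.utils.checkpoint import checkpoint

    torch.manual_seed(0)
    block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU(), torch.nn.Linear(8, 8))
    x = torch.randn(4, 8, requires_grad=True)

    # reference run without checkpointing
    out_ref = block(x)
    out_ref.sum().backward()
    grads_ref = [p.grad.clone() for p in block.parameters()]

    # checkpointed run: activations of `block` are recomputed during the backward pass
    block.zero_grad()
    out_ckpt = checkpoint(block, x)
    out_ckpt.sum().backward()
    grads_ckpt = [p.grad.clone() for p in block.parameters()]

    assert torch.equal(out_ref, out_ckpt)
    assert all(torch.allclose(a, b, atol=1e-6) for a, b in zip(grads_ref, grads_ckpt))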
@@ -231,6 +307,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @slow
    def test_from_pretrained_hub(self):
        model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
        self.assertIsNotNone(model)
@@ -244,6 +321,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
        assert image is not None, "Make sure output is not None"

    @slow
    def test_output_pretrained_ve_mid(self):
        model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
        model.to(torch_device)
60
tests/test_outputs.py
Normal file
@@ -0,0 +1,60 @@
import unittest
from dataclasses import dataclass
from typing import List, Union

import numpy as np

import PIL.Image
from diffusers.utils.outputs import BaseOutput


@dataclass
class CustomOutput(BaseOutput):
    images: Union[List[PIL.Image.Image], np.ndarray]


class ConfigTester(unittest.TestCase):
    def test_outputs_single_attribute(self):
        outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4))

        # check every way of getting the attribute
        assert isinstance(outputs.images, np.ndarray)
        assert outputs.images.shape == (1, 3, 4, 4)
        assert isinstance(outputs["images"], np.ndarray)
        assert outputs["images"].shape == (1, 3, 4, 4)
        assert isinstance(outputs[0], np.ndarray)
        assert outputs[0].shape == (1, 3, 4, 4)

        # test with a non-tensor attribute
        outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])

        # check every way of getting the attribute
        assert isinstance(outputs.images, list)
        assert isinstance(outputs.images[0], PIL.Image.Image)
        assert isinstance(outputs["images"], list)
        assert isinstance(outputs["images"][0], PIL.Image.Image)
        assert isinstance(outputs[0], list)
        assert isinstance(outputs[0][0], PIL.Image.Image)

    def test_outputs_dict_init(self):
        # test output reinitialization with a `dict` for compatibility with `accelerate`
        outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)})

        # check every way of getting the attribute
        assert isinstance(outputs.images, np.ndarray)
        assert outputs.images.shape == (1, 3, 4, 4)
        assert isinstance(outputs["images"], np.ndarray)
        assert outputs["images"].shape == (1, 3, 4, 4)
        assert isinstance(outputs[0], np.ndarray)
        assert outputs[0].shape == (1, 3, 4, 4)

        # test with a non-tensor attribute
        outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})

        # check every way of getting the attribute
        assert isinstance(outputs.images, list)
        assert isinstance(outputs.images[0], PIL.Image.Image)
        assert isinstance(outputs["images"], list)
        assert isinstance(outputs["images"][0], PIL.Image.Image)
        assert isinstance(outputs[0], list)
        assert isinstance(outputs[0][0], PIL.Image.Image)
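A brief, hedged aside on the second test above: `accelerate` may hand pipeline outputs around as plain dicts, so a `BaseOutput` subclass has to accept a single dict positional argument and still expose attribute, key, and index access. Reusing the `CustomOutput` dataclass defined in this new file, with an illustrative shape:

    payload = {"images": np.random.rand(2, 3, 8, 8)}  # e.g. a gathered/serialized output
    restored = CustomOutput(payload)                   # dict passed as the single positional argument
    assert restored.images.shape == (2, 3, 8, 8)       # attribute access
    assert restored["images"].shape == (2, 3, 8, 8)    # key access
    assert restored[0].shape == (2, 3, 8, 8)           # index access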
@@ -46,11 +46,10 @@ from diffusers import (
    UNet2DModel,
    VQModel,
)
from diffusers.modeling_utils import WEIGHTS_NAME
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils import CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -1105,7 +1104,7 @@ class PipelineTesterMixin(unittest.TestCase):
        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 256, 256, 3)
        expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
        expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    @slow
@@ -1325,14 +1324,58 @@ class PipelineTesterMixin(unittest.TestCase):
        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 1e-2

    @slow
    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
    def test_stable_diffusion_inpaint_pipeline_k_lms(self):
        init_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
        )
        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/in_paint/red_cat_sitting_on_a_park_bench_k_lms.png"
        )
        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

        lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

        model_id = "CompVis/stable-diffusion-v1-4"
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_id,
            scheduler=lms,
            safety_checker=self.dummy_safety_checker,
            use_auth_token=True,
        )
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "A red cat sitting on a park bench"

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipe(
            prompt=prompt,
            init_image=init_image,
            mask_image=mask_image,
            strength=0.75,
            guidance_scale=7.5,
            generator=generator,
            output_type="np",
        )
        image = output.images[0]

        assert image.shape == (512, 512, 3)
        assert np.abs(expected_image - image).max() < 1e-2

    @slow
    def test_stable_diffusion_onnx(self):
        from scripts.convert_stable_diffusion_checkpoint_to_onnx import convert_models

        with tempfile.TemporaryDirectory() as tmpdirname:
            convert_models("CompVis/stable-diffusion-v1-4", tmpdirname, opset=14)

            sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(tmpdirname, provider="CUDAExecutionProvider")
        sd_pipe = StableDiffusionOnnxPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider", use_auth_token=True
        )

        prompt = "A painting of a squirrel eating a burger"
        np.random.seed(0)
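A small, self-contained illustration (not from this diff) of the seeding pattern these slow pipeline tests rely on: a `torch.Generator` bound to a device and seeded with `manual_seed` makes the sampled latents, and therefore the decoded images, reproducible across runs. The tensor shape below is arbitrary.

    import torch

    gen = torch.Generator(device="cpu").manual_seed(0)
    first = torch.randn(2, 4, generator=gen)

    gen = torch.Generator(device="cpu").manual_seed(0)
    second = torch.randn(2, 4, generator=gen)

    assert torch.equal(first, second)  # same seed and device -> identical draws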
@@ -1,101 +0,0 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_copies.py
REPO_PATH = "."

# Internal TensorFlow ops that can be safely ignored (mostly specific to a saved model)
INTERNAL_OPS = [
    "Assert",
    "AssignVariableOp",
    "EmptyTensorList",
    "MergeV2Checkpoints",
    "ReadVariableOp",
    "ResourceGather",
    "RestoreV2",
    "SaveV2",
    "ShardedFilename",
    "StatefulPartitionedCall",
    "StaticRegexFullMatch",
    "VarHandleOp",
]


def onnx_compliance(saved_model_path, strict, opset):
    saved_model = SavedModel()
    onnx_ops = []

    with open(os.path.join(REPO_PATH, "utils", "tf_ops", "onnx.json")) as f:
        onnx_opsets = json.load(f)["opsets"]

    for i in range(1, opset + 1):
        onnx_ops.extend(onnx_opsets[str(i)])

    with open(saved_model_path, "rb") as f:
        saved_model.ParseFromString(f.read())

    model_op_names = set()

    # Iterate over every metagraph in case there is more than one (a saved model can contain multiple graphs)
    for meta_graph in saved_model.meta_graphs:
        # Add operations in the graph definition
        model_op_names.update(node.op for node in meta_graph.graph_def.node)

        # Go through the functions in the graph definition
        for func in meta_graph.graph_def.library.function:
            # Add operations in each function
            model_op_names.update(node.op for node in func.node_def)

    # Convert to list, sorted if you want
    model_op_names = sorted(model_op_names)
    incompatible_ops = []

    for op in model_op_names:
        if op not in onnx_ops and op not in INTERNAL_OPS:
            incompatible_ops.append(op)

    if strict and len(incompatible_ops) > 0:
        raise Exception(f"Found the following incompatible ops for the opset {opset}:\n" + incompatible_ops)
    elif len(incompatible_ops) > 0:
        print(f"Found the following incompatible ops for the opset {opset}:")
        print(*incompatible_ops, sep="\n")
    else:
        print(f"The saved model {saved_model_path} can properly be converted with ONNX.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--saved_model_path", help="Path of the saved model to check (the .pb file).")
    parser.add_argument(
        "--opset", default=12, type=int, help="The ONNX opset against which the model has to be tested."
    )
    parser.add_argument(
        "--framework", choices=["onnx"], default="onnx", help="Frameworks against which to test the saved model."
    )
    parser.add_argument(
        "--strict", action="store_true", help="Whether make the checking strict (raise errors) or not (raise warnings)"
    )
    args = parser.parse_args()

    if args.framework == "onnx":
        onnx_compliance(args.saved_model_path, args.strict, args.opset)
34
utils/get_modified_files.py
Normal file
@@ -0,0 +1,34 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# this script reports modified .py files under the desired list of top-level sub-dirs passed as a list of arguments, e.g.:
# python ./utils/get_modified_files.py utils src tests examples
#
# it uses git to find the forking point and which files were modified - i.e. files not under git won't be considered
# since the output of this script is fed into Makefile commands it doesn't print a newline after the results

import re
import subprocess
import sys


fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8")
modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split()

joined_dirs = "|".join(sys.argv[1:])
regex = re.compile(rf"^({joined_dirs}).*?\.py$")

relevant_modified_files = [x for x in modified_files if regex.match(x)]
print(" ".join(relevant_modified_files), end="")
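A quick, hedged illustration of the filtering step above, using made-up paths: only `.py` files under the requested top-level directories survive the regex.

    import re

    joined_dirs = "|".join(["utils", "src", "tests", "examples"])
    regex = re.compile(rf"^({joined_dirs}).*?\.py$")

    candidates = ["src/diffusers/models/unet_2d.py", "docs/source/index.mdx", "tests/test_outputs.py"]
    print([p for p in candidates if regex.match(p)])
    # -> ['src/diffusers/models/unet_2d.py', 'tests/test_outputs.py']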