diff --git a/README.md b/README.md index f6889baf92..7f2704e5d6 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,56 @@ image_pil = PIL.Image.fromarray(image_processed[0]) image_pil.save("test.png") ``` +#### **Example 1024x1024 image generation with SDE VE** + +See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE. + +```python +from diffusers import DiffusionPipeline +import torch +import PIL.Image +import numpy as np + +torch.manual_seed(32) + +score_sde_sv = DiffusionPipeline.from_pretrained("fusing/ffhq_ncsnpp") + +# Note this might take up to 3 minutes on a GPU +image = score_sde_sv(num_inference_steps=2000) + +image = image.permute(0, 2, 3, 1).cpu().numpy() +image = np.clip(image * 255, 0, 255).astype(np.uint8) +image_pil = PIL.Image.fromarray(image[0]) + +# save image +image_pil.save("test.png") +``` +#### **Example 32x32 image generation with SDE VP** + +See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE. + +```python +from diffusers import DiffusionPipeline +import torch +import PIL.Image +import numpy as np + +torch.manual_seed(32) + +score_sde_sv = DiffusionPipeline.from_pretrained("fusing/cifar10-ddpmpp-deep-vp") + +# Note this might take up to 3 minutes on a GPU +image = score_sde_sv(num_inference_steps=1000) + +image = image.permute(0, 2, 3, 1).cpu().numpy() +image = np.clip(image * 255, 0, 255).astype(np.uint8) +image_pil = PIL.Image.fromarray(image[0]) + +# save image +image_pil.save("test.png") +``` + + #### **Text to Image generation with Latent Diffusion** _Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._ @@ -249,24 +299,24 @@ image_pil = PIL.Image.fromarray(image_processed[0]) image_pil.save("test.png") ``` -#### **Text to speech with GradTTS and BDDM** +#### **Text to speech with GradTTS and BDDMPipeline** ```python import torch -from diffusers import BDDM, GradTTS +from diffusers import BDDMPipeline, GradTTSPipeline torch_device = "cuda" # load grad tts and bddm pipelines -grad_tts = GradTTS.from_pretrained("fusing/grad-tts-libri-tts") -bddm = BDDM.from_pretrained("fusing/diffwave-vocoder-ljspeech") +grad_tts = GradTTSPipeline.from_pretrained("fusing/grad-tts-libri-tts") +bddm = BDDMPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech") text = "Hello world, I missed you so much." 
# generate mel spectograms using text mel_spec = grad_tts(text, torch_device=torch_device) -# generate the speech by passing mel spectograms to BDDM pipeline +# generate the speech by passing mel spectograms to BDDMPipeline pipeline generator = torch.manual_seed(42) audio = bddm(mel_spec, generator, torch_device=torch_device) @@ -288,3 +338,4 @@ wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy()) - [ ] Add more vision models - [ ] Add more speech models - [ ] Add RL model +- [ ] Add FID and KID metrics diff --git a/run.py b/run.py new file mode 100755 index 0000000000..cae9713967 --- /dev/null +++ b/run.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +import numpy as np +import PIL +import torch +#from configs.ve import ffhq_ncsnpp_continuous as configs +# from configs.ve import cifar10_ncsnpp_continuous as configs + + +device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + +torch.backends.cuda.matmul.allow_tf32 = False +torch.manual_seed(0) + + +class NewReverseDiffusionPredictor: + def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0): + super().__init__() + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.N = N + self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) + + self.probability_flow = probability_flow + self.score_fn = score_fn + + def discretize(self, x, t): + timestep = (t * (self.N - 1)).long() + sigma = self.discrete_sigmas.to(t.device)[timestep] + adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), + self.discrete_sigmas[timestep - 1].to(t.device)) + f = torch.zeros_like(x) + G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2) + + labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t + result = self.score_fn(x, labels) + + rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.) 
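+        # for the probability-flow ODE the stochastic term is dropped, so the reverse diffusion coefficient below is zero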
+ rev_G = torch.zeros_like(G) if self.probability_flow else G + return rev_f, rev_G + + def update_fn(self, x, t): + f, G = self.discretize(x, t) + z = torch.randn_like(x) + x_mean = x - f + x = x_mean + G[:, None, None, None] * z + return x, x_mean + + +class NewLangevinCorrector: + def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0): + super().__init__() + self.score_fn = score_fn + self.snr = snr + self.n_steps = n_steps + + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + def update_fn(self, x, t): + score_fn = self.score_fn + n_steps = self.n_steps + target_snr = self.snr +# if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE): +# timestep = (t * (sde.N - 1) / sde.T).long() +# alpha = sde.alphas.to(t.device)[timestep] +# else: + alpha = torch.ones_like(t) + + for i in range(n_steps): + labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t + grad = score_fn(x, labels) + noise = torch.randn_like(x) + grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() + step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha + x_mean = x + step_size[:, None, None, None] * grad + x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise + + return x, x_mean + + + +def save_image(x): + image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8) + image_pil = PIL.Image.fromarray(image_processed[0]) + image_pil.save("../images/hey.png") + + +# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth" +#ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth" +# Note usually we need to restore ema etc... +# ema restored checkpoint used from below + +N = 2 +sigma_min = 0.01 +sigma_max = 1348 +sampling_eps = 1e-5 +batch_size = 1 +centered = False + +from diffusers import NCSNpp + +model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device) +model = torch.nn.DataParallel(model) + +img_size = model.module.config.image_size +channels = model.module.config.num_channels +shape = (batch_size, channels, img_size, img_size) +probability_flow = False +snr = 0.15 +n_steps = 1 + + +new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max) +new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N) + +with torch.no_grad(): + # Initial sample + x = torch.randn(*shape) * sigma_max + x = x.to(device) + timesteps = torch.linspace(1, sampling_eps, N, device=device) + + for i in range(N): + t = timesteps[i] + vec_t = torch.ones(shape[0], device=t.device) * t + x, x_mean = new_corrector.update_fn(x, vec_t) + x, x_mean = new_predictor.update_fn(x, vec_t) + + x = x_mean + if centered: + x = (x + 1.) / 2. 
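For reference, the continuous noise level `sigma_min * (sigma_max / sigma_min) ** t` used by both the corrector and the predictor above coincides with the discrete sigma ladder built in `NewReverseDiffusionPredictor.__init__` at the grid timesteps `t = i / (N - 1)`. A small self-contained check (illustrative, not part of the script; `N` is enlarged here just for the comparison):

```python
import numpy as np
import torch

# sigma_min / sigma_max as in run.py; N enlarged for illustration
sigma_min, sigma_max, N = 0.01, 1348, 10

# discrete ladder: exp(linspace(log(sigma_min), log(sigma_max), N))
discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), N))

# continuous schedule evaluated at the grid points t = i / (N - 1)
t = torch.arange(N) / (N - 1)
continuous_sigmas = sigma_min * (sigma_max / sigma_min) ** t

assert torch.allclose(discrete_sigmas, continuous_sigmas, rtol=1e-4)
```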
+ + +# save_image(x) + +# for 5 cifar10 +x_sum = 106071.9922 +x_mean = 34.52864456176758 + +# for 1000 cifar10 +x_sum = 461.9700 +x_mean = 0.1504 + +# for 2 for 1024 +x_sum = 3382810112.0 +x_mean = 1075.366455078125 + +def check_x_sum_x_mean(x, x_sum, x_mean): + assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" + assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" + + +check_x_sum_x_mean(x, x_sum, x_mean) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index efb89e8597..213b9a5bcc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -7,23 +7,29 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode __version__ = "0.0.4" from .modeling_utils import ModelMixin -from .models.unet import UNetModel -from .models.unet_ldm import UNetLDMModel -from .models.unet_rl import TemporalUNet +from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel from .pipeline_utils import DiffusionPipeline -from .pipelines import BDDM, DDIM, DDPM, PNDM -from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin +from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline, ScoreSdeVpPipeline +from .schedulers import ( + DDIMScheduler, + DDPMScheduler, + GradTTSScheduler, + PNDMScheduler, + SchedulerMixin, + ScoreSdeVeScheduler, + ScoreSdeVpScheduler, +) if is_transformers_available(): from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel from .models.unet_grad_tts import UNetGradTTSModel - from .pipelines import Glide, LatentDiffusion + from .pipelines import GlidePipeline, LatentDiffusionPipeline else: from .utils.dummy_transformers_objects import * if is_transformers_available() and is_inflect_available() and is_unidecode_available(): - from .pipelines import GradTTS + from .pipelines import GradTTSPipeline else: from .utils.dummy_transformers_and_inflect_and_unidecode_objects import * diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 3f0c78b3c6..71e321e111 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide from .unet_grad_tts import UNetGradTTSModel from .unet_ldm import UNetLDMModel from .unet_rl import TemporalUNet +from .unet_sde_score_estimation import NCSNpp diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention2d.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py new file mode 100644 index 0000000000..f31b64ee5c --- /dev/null +++ b/src/diffusers/models/embeddings.py @@ -0,0 +1,85 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
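The new `src/diffusers/models/embeddings.py` module below consolidates the sinusoidal timestep embeddings that were previously duplicated across the UNet variants. A quick usage sketch of the two helpers it defines (illustrative; assumes the diff is applied):

```python
import torch
from diffusers.models.embeddings import GaussianFourierProjection, get_timestep_embedding

timesteps = torch.arange(4)  # one (possibly fractional) timestep per batch element

# DDPM-style sin/cos embedding: shape [N, embedding_dim]
emb = get_timestep_embedding(timesteps, 128)
assert emb.shape == (4, 128)

# GLIDE/LDM call sites in this diff pass flip_sin_to_cos=True, downscale_freq_shift=0
# to match the per-model implementations that are removed further down.
glide_style = get_timestep_embedding(timesteps, 128, flip_sin_to_cos=True, downscale_freq_shift=0)

# Gaussian Fourier features for continuous noise levels (used by NCSNpp): output is 2 * embedding_size
proj = GaussianFourierProjection(embedding_size=16, scale=16.0)
fourier = proj(torch.log(torch.tensor([0.5, 1.0])))
assert fourier.shape == (2, 32)
```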
+import math + +import numpy as np +import torch +from torch import nn + + +def get_timestep_embedding( + timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000 +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + + emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift) + emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + emb = torch.exp(emb * emb_coeff) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +# unet_sde_score_estimation.py +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__(self, embedding_size=256, scale=1.0): + super().__init__() + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * np.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + + +# unet_rl.py - TODO(need test) +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py new file mode 100644 index 0000000000..04e3735d60 --- /dev/null +++ b/src/diffusers/models/resnet.py @@ -0,0 +1,278 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + +def conv_transpose_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. 
+ """ + if dims == 1: + return nn.ConvTranspose1d(*args, **kwargs) + elif dims == 2: + return nn.ConvTranspose2d(*args, **kwargs) + elif dims == 3: + return nn.ConvTranspose3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +def nonlinearity(x, swish=1.0): + # swish + if swish == 1.0: + return F.silu(x) + else: + return x * F.sigmoid(x * float(swish)) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.use_conv_transpose = use_conv_transpose + + if use_conv_transpose: + self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1) + elif use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + else: + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + if self.use_conv: + x = self.conv(x) + + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.padding = padding + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv and self.padding == 0 and self.dims == 2: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + return self.down(x) + + +class UNetUpsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + +class GlideUpsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class LDMUpsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class GradTTSUpsample(torch.nn.Module): + def __init__(self, dim): + super(Upsample, self).__init__() + self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Upsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +# class ResnetBlock(nn.Module): +# def __init__( +# self, +# *, +# in_channels, +# out_channels=None, +# conv_shortcut=False, +# dropout, +# temb_channels=512, +# use_scale_shift_norm=False, +# ): +# super().__init__() +# self.in_channels = in_channels +# out_channels = in_channels if out_channels is None else out_channels +# self.out_channels = out_channels +# self.use_conv_shortcut = conv_shortcut +# self.use_scale_shift_norm = use_scale_shift_norm + +# self.norm1 = Normalize(in_channels) +# self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + +# temp_out_channles = 2 * out_channels if use_scale_shift_norm else out_channels +# self.temb_proj = torch.nn.Linear(temb_channels, temp_out_channles) + +# self.norm2 = Normalize(out_channels) +# self.dropout = torch.nn.Dropout(dropout) +# self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) +# if self.in_channels != self.out_channels: +# if self.use_conv_shortcut: +# self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) +# else: +# self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + +# def forward(self, x, temb): +# h = x +# h = self.norm1(h) +# h = nonlinearity(h) +# h = self.conv1(h) + +# # TODO: check if this broadcasting works correctly for 1D and 3D +# temb = self.temb_proj(nonlinearity(temb))[:, :, None, None] + +# if self.use_scale_shift_norm: +# out_norm, out_rest = self.out_layers[0], self.out_layers[1:] +# scale, shift = torch.chunk(temb, 2, dim=1) +# h = self.norm2(h) 
* (1 + scale) + shift +# h = out_rest(h) +# else: +# h = h + temb +# h = self.norm2(h) +# h = nonlinearity(h) +# h = self.dropout(h) +# h = self.conv2(h) + +# if self.in_channels != self.out_channels: +# if self.use_conv_shortcut: +# x = self.conv_shortcut(x) +# else: +# x = self.nin_shortcut(x) + +# return x + h diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index a4e1e22df8..1749def9b1 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -30,27 +30,7 @@ from tqdm import tqdm from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb +from .embeddings import get_timestep_embedding def nonlinearity(x): diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 648ff9c34a..c154db9210 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .embeddings import get_timestep_embedding def convert_module_to_f16(l): @@ -86,27 +87,6 @@ def normalization(channels, swish=0.0): return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) -def timestep_embedding(timesteps, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=timesteps.device - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def zero_module(module): """ Zero out the parameters of a module and return it. 
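The GLIDE call sites updated below switch to the shared helper with `flip_sin_to_cos=True` and `downscale_freq_shift=0`, which reproduces the removed `timestep_embedding` (cosine-then-sine ordering, frequency denominator `half`). A quick equivalence check (illustrative):

```python
import math

import torch
from diffusers.models.embeddings import get_timestep_embedding


def old_glide_timestep_embedding(timesteps, dim, max_period=10000):
    # the implementation removed in this diff
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


t = torch.arange(8)
assert torch.allclose(
    old_glide_timestep_embedding(t, 64),
    get_timestep_embedding(t, 64, flip_sin_to_cos=True, downscale_freq_shift=0),
    atol=1e-6,
)
```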
@@ -627,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin): """ hs = [] - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) h = x.type(self.dtype) for module in self.input_blocks: @@ -714,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel): def forward(self, x, timesteps, transformer_out=None): hs = [] - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) # project the last token transformer_proj = self.transformer_proj(transformer_out[:, -1]) @@ -806,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel): x = torch.cat([x, upsampled], dim=1) hs = [] - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) h = x for module in self.input_blocks: diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index a2bdd951e4..36bcce53e9 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -1,5 +1,3 @@ -import math - import torch @@ -11,6 +9,7 @@ except: from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .embeddings import get_timestep_embedding class Mish(torch.nn.Module): @@ -107,21 +106,6 @@ class Residual(torch.nn.Module): return output -class SinusoidalPosEmb(torch.nn.Module): - def __init__(self, dim): - super(SinusoidalPosEmb, self).__init__() - self.dim = dim - - def forward(self, x, scale=1000): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) - emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - class UNetGradTTSModel(ModelMixin, ConfigMixin): def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000): super(UNetGradTTSModel, self).__init__() @@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats) ) - self.time_pos_emb = SinusoidalPosEmb(dim) self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)) dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)] @@ -198,7 +181,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): if not isinstance(spk, type(None)): s = self.spk_mlp(spk) - t = self.time_pos_emb(timesteps, scale=self.pe_scale) + t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale) t = self.mlp(t) if self.n_spks < 2: diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index cca3231341..da84391a36 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -16,6 +16,7 @@ except: from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +from .embeddings import get_timestep_embedding def exists(val): @@ -316,36 +317,6 @@ def normalization(channels, swish=0.0): return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) -def timestep_embedding(timesteps, dim, max_period=10000): - """ - Create 
sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( - device=timesteps.device - ) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - ## go class AttentionPool2d(nn.Module): """ @@ -1026,7 +997,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): hs = [] if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device) - t_emb = timestep_embedding(timesteps, self.model_channels) + t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) emb = self.time_embed(t_emb) if self.num_classes is not None: @@ -1240,7 +1211,9 @@ class EncoderUNetModel(nn.Module): :param timesteps: a 1-D batch of timesteps. :return: an [N x K] Tensor of outputs. """ - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = self.time_embed( + get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) + ) results = [] h = x.type(self.dtype) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 55654dc62e..28fea5753c 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -13,7 +13,6 @@ except: print("Einops is not installed") pass - from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin @@ -107,14 +106,21 @@ class ResidualTemporalBlock(nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): def __init__( self, - horizon, + training_horizon, transition_dim, cond_dim, + predict_epsilon=False, + clip_denoised=True, dim=32, dim_mults=(1, 2, 4, 8), ): super().__init__() + self.transition_dim = transition_dim + self.cond_dim = cond_dim + self.predict_epsilon = predict_epsilon + self.clip_denoised = clip_denoised + dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) # print(f'[ models/temporal ] Channel dimensions: {in_out}') @@ -138,19 +144,19 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): self.downs.append( nn.ModuleList( [ - ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon), - ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon), + ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon), + ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon), Downsample1d(dim_out) if not is_last else nn.Identity(), ] ) ) if not is_last: - horizon = horizon // 2 + training_horizon = training_horizon // 2 mid_dim = dims[-1] - self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) - self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon) + self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, 
embed_dim=time_dim, horizon=training_horizon) + self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) @@ -158,15 +164,15 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): self.ups.append( nn.ModuleList( [ - ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon), - ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon), + ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon), + ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon), Upsample1d(dim_in) if not is_last else nn.Identity(), ] ) ) if not is_last: - horizon = horizon * 2 + training_horizon = training_horizon * 2 self.final_conv = nn.Sequential( Conv1dBlock(dim, dim, kernel_size=5), @@ -232,7 +238,6 @@ class TemporalValue(nn.Module): print(in_out) for dim_in, dim_out in in_out: - self.blocks.append( nn.ModuleList( [ diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py new file mode 100644 index 0000000000..83700c4b63 --- /dev/null +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -0,0 +1,1061 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. 
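The new `unet_sde_score_estimation.py` file below ports the NCSN++ score network (Gaussian Fourier noise-level embeddings, FIR up/down-sampling, BigGAN-style residual blocks, progressive skip connections). A minimal CPU smoke test with a deliberately small configuration (illustrative sketch only; it assumes the defaults `embedding_type="fourier"`, `resblock_type="biggan"`, `progressive="output_skip"` and runs with random weights):

```python
import torch
from diffusers import NCSNpp  # exported from diffusers.models in this diff

# tiny config: 32x32 images, two resolutions, attention at 16x16
model = NCSNpp(image_size=32, num_channels=3, nf=16, ch_mult=(1, 2), attn_resolutions=(16,), num_res_blocks=1)

x = torch.randn(2, 3, 32, 32)
sigmas = torch.tensor([0.5, 1.0])  # continuous noise levels, one per sample

with torch.no_grad():
    score = model(x, sigmas)  # with scale_by_sigma=True the output is divided by sigma

assert score.shape == x.shape
```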
+ +# helpers functions + +import functools +import string + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..configuration_utils import ConfigMixin +from ..modeling_utils import ModelMixin +from .embeddings import GaussianFourierProjection, get_timestep_embedding + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[ + :, + max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), + max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), + :, + ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) + + +# Function ported from StyleGAN2 +def get_weight(module, shape, weight_var="weight", kernel_init=None): + """Get/create weight tensor for a convolution or fully-connected layer.""" + + return module.param(weight_var, kernel_init, shape) + + +class Conv2d(nn.Module): + """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" + + def __init__( + self, + in_ch, + out_ch, + kernel, + up=False, + down=False, + resample_kernel=(1, 3, 3, 1), + use_bias=True, + kernel_init=None, + ): + super().__init__() + assert not (up and down) + assert kernel >= 1 and kernel % 2 == 1 + self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) + if kernel_init is not None: + self.weight.data = kernel_init(self.weight.data.shape) + if use_bias: + self.bias = nn.Parameter(torch.zeros(out_ch)) + + self.up = up + self.down = down + self.resample_kernel = resample_kernel + self.kernel = kernel + self.use_bias = use_bias + + def forward(self, x): + if self.up: + x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) + elif self.down: + x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) + else: + x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) + + if self.use_bias: + x = x + self.bias.reshape(1, -1, 1, 1) + + return x + + +def naive_upsample_2d(x, factor=2): + _N, C, H, W = x.shape + x = torch.reshape(x, (-1, C, H, 1, W, 1)) + x = x.repeat(1, 1, 1, factor, 1, factor) + return torch.reshape(x, (-1, C, H * factor, W * factor)) + + +def naive_downsample_2d(x, factor=2): + _N, C, H, W = x.shape + x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) + return torch.mean(x, dim=(3, 5)) + + +def upsample_conv_2d(x, w, k=None, factor=2, gain=1): + """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. 
+ + Padding is performed only once at the beginning, not between the + operations. + The fused op is considerably more efficient than performing the same + calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = + x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` or + `[N, H * factor, W * factor, C]`, and same datatype as `x`. + """ + + assert isinstance(factor, int) and factor >= 1 + + # Check weight shape. + assert len(w.shape) == 4 + convH = w.shape[2] + convW = w.shape[3] + inC = w.shape[1] + + assert convW == convH + + # Setup filter kernel. + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor**2)) + p = (k.shape[0] - factor) - (convW - 1) + + stride = (factor, factor) + + # Determine data dimensions. + stride = [1, 1, factor, factor] + output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) + output_padding = ( + output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, + output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW, + ) + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = _shape(x, 1) // inC + + # Transpose weights. + w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) + w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) + w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) + + x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) + # Original TF code. + # x = tf.nn.conv2d_transpose( + # x, + # w, + # output_shape=output_shape, + # strides=stride, + # padding='VALID', + # data_format=data_format) + # JAX equivalent + + return upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) + + +def conv_downsample_2d(x, w, k=None, factor=2, gain=1): + """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. + + Padding is performed only once at the beginning, not between the operations. + The fused op is considerably more efficient than performing the same + calculation + using standard TensorFlow ops. It supports gradients of arbitrary order. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + w: Weight tensor of the shape `[filterH, filterW, inChannels, + outChannels]`. Grouped convolution can be performed by `inChannels = + x.shape[0] // numGroups`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` or + `[N, H // factor, W // factor, C]`, and same datatype as `x`. 
+ """ + + assert isinstance(factor, int) and factor >= 1 + _outC, _inC, convH, convW = w.shape + assert convW == convH + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = (k.shape[0] - factor) + (convW - 1) + s = [factor, factor] + x = upfirdn2d(x, torch.tensor(k, device=x.device), pad=((p + 1) // 2, p // 2)) + return F.conv2d(x, w, stride=s, padding=0) + + +def _setup_kernel(k): + k = np.asarray(k, dtype=np.float32) + if k.ndim == 1: + k = np.outer(k, k) + k /= np.sum(k) + assert k.ndim == 2 + assert k.shape[0] == k.shape[1] + return k + + +def _shape(x, dim): + return x.shape[dim] + + +def upsample_2d(x, k=None, factor=2, gain=1): + r"""Upsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and upsamples each image with the given filter. The filter is normalized so + that + if the input pixels are constant, they will be scaled by the specified + `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded + with + zeros so that its shape is a multiple of the upsampling factor. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + nearest-neighbor upsampling. + factor: Integer upsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). + + Returns: + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * (gain * (factor**2)) + p = k.shape[0] - factor + return upfirdn2d(x, torch.tensor(k, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) + + +def downsample_2d(x, k=None, factor=2, gain=1): + r"""Downsample a batch of 2D images with the given filter. + + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` + and downsamples each image with the given filter. The filter is normalized + so that + if the input pixels are constant, they will be scaled by the specified + `gain`. + Pixels outside the image are assumed to be zero, and the filter is padded + with + zeros so that its shape is a multiple of the downsampling factor. + Args: + x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + C]`. + k: FIR filter of the shape `[firH, firW]` or `[firN]` + (separable). The default is `[1] * factor`, which corresponds to + average pooling. + factor: Integer downsampling factor (default: 2). + gain: Scaling factor for signal magnitude (default: 1.0). 
+ + Returns: + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + + assert isinstance(factor, int) and factor >= 1 + if k is None: + k = [1] * factor + k = _setup_kernel(k) * gain + p = k.shape[0] - factor + return upfirdn2d(x, torch.tensor(k, device=x.device), down=factor, pad=((p + 1) // 2, p // 2)) + + +def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1.0, padding=0): + """1x1 convolution with DDPM initialization.""" + conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1.0, padding=1): + """3x3 convolution with DDPM initialization.""" + conv = nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias + ) + conv.weight.data = default_init(init_scale)(conv.weight.data.shape) + nn.init.zeros_(conv.bias) + return conv + + +conv1x1 = ddpm_conv1x1 +conv3x3 = ddpm_conv3x3 + + +def _einsum(a, b, c, x, y): + einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c)) + return torch.einsum(einsum_str, x, y) + + +def contract_inner(x, y): + """tensordot(x, y, 1).""" + x_chars = list(string.ascii_lowercase[: len(x.shape)]) + y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)]) + y_chars[0] = x_chars[-1] # first axis of y and last of x get summed + out_chars = x_chars[:-1] + y_chars[1:] + return _einsum(x_chars, y_chars, out_chars, x, y) + + +class NIN(nn.Module): + def __init__(self, in_dim, num_units, init_scale=0.1): + super().__init__() + self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True) + self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + y = contract_inner(x, self.W) + self.b + return y.permute(0, 3, 1, 2) + + +def get_act(nonlinearity): + """Get activation functions from the config file.""" + + if nonlinearity.lower() == "elu": + return nn.ELU() + elif nonlinearity.lower() == "relu": + return nn.ReLU() + elif nonlinearity.lower() == "lrelu": + return nn.LeakyReLU(negative_slope=0.2) + elif nonlinearity.lower() == "swish": + return nn.SiLU() + else: + raise NotImplementedError("activation function does not exist!") + + +def default_init(scale=1.0): + """The same initialization used in DDPM.""" + scale = 1e-10 if scale == 0 else scale + return variance_scaling(scale, "fan_avg", "uniform") + + +def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"): + """Ported from JAX.""" + + def _compute_fans(shape, in_axis=1, out_axis=0): + receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] + fan_in = shape[in_axis] * receptive_field_size + fan_out = shape[out_axis] * receptive_field_size + return fan_in, fan_out + + def init(shape, dtype=dtype, device=device): + fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) + if mode == "fan_in": + denominator = fan_in + elif mode == "fan_out": + denominator = fan_out + elif mode == "fan_avg": + denominator = (fan_in + fan_out) / 2 + else: + raise ValueError("invalid mode for variance scaling initializer: {}".format(mode)) + variance = scale / denominator + if distribution == "normal": + return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) + elif distribution == "uniform": + return 
(torch.rand(*shape, dtype=dtype, device=device) * 2.0 - 1.0) * np.sqrt(3 * variance) + else: + raise ValueError("invalid distribution for variance scaling initializer") + + return init + + +class Combine(nn.Module): + """Combine information from skip connections.""" + + def __init__(self, dim1, dim2, method="cat"): + super().__init__() + self.Conv_0 = conv1x1(dim1, dim2) + self.method = method + + def forward(self, x, y): + h = self.Conv_0(x) + if self.method == "cat": + return torch.cat([h, y], dim=1) + elif self.method == "sum": + return h + y + else: + raise ValueError(f"Method {self.method} not recognized.") + + +class AttnBlockpp(nn.Module): + """Channel-wise self-attention block. Modified from DDPM.""" + + def __init__(self, channels, skip_rescale=False, init_scale=0.0): + super().__init__() + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6) + self.NIN_0 = NIN(channels, channels) + self.NIN_1 = NIN(channels, channels) + self.NIN_2 = NIN(channels, channels) + self.NIN_3 = NIN(channels, channels, init_scale=init_scale) + self.skip_rescale = skip_rescale + + def forward(self, x): + B, C, H, W = x.shape + h = self.GroupNorm_0(x) + q = self.NIN_0(h) + k = self.NIN_1(h) + v = self.NIN_2(h) + + w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) + w = torch.reshape(w, (B, H, W, H * W)) + w = F.softmax(w, dim=-1) + w = torch.reshape(w, (B, H, W, H, W)) + h = torch.einsum("bhwij,bcij->bchw", w, v) + h = self.NIN_3(h) + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class Upsample(nn.Module): + def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_ch = out_ch if out_ch else in_ch + if not fir: + if with_conv: + self.Conv_0 = conv3x3(in_ch, out_ch) + else: + if with_conv: + self.Conv2d_0 = Conv2d( + in_ch, + out_ch, + kernel=3, + up=True, + resample_kernel=fir_kernel, + use_bias=True, + kernel_init=default_init(), + ) + self.fir = fir + self.with_conv = with_conv + self.fir_kernel = fir_kernel + self.out_ch = out_ch + + def forward(self, x): + B, C, H, W = x.shape + if not self.fir: + h = F.interpolate(x, (H * 2, W * 2), "nearest") + if self.with_conv: + h = self.Conv_0(h) + else: + if not self.with_conv: + h = upsample_2d(x, self.fir_kernel, factor=2) + else: + h = self.Conv2d_0(x) + + return h + + +class Downsample(nn.Module): + def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, fir_kernel=(1, 3, 3, 1)): + super().__init__() + out_ch = out_ch if out_ch else in_ch + if not fir: + if with_conv: + self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) + else: + if with_conv: + self.Conv2d_0 = Conv2d( + in_ch, + out_ch, + kernel=3, + down=True, + resample_kernel=fir_kernel, + use_bias=True, + kernel_init=default_init(), + ) + self.fir = fir + self.fir_kernel = fir_kernel + self.with_conv = with_conv + self.out_ch = out_ch + + def forward(self, x): + B, C, H, W = x.shape + if not self.fir: + if self.with_conv: + x = F.pad(x, (0, 1, 0, 1)) + x = self.Conv_0(x) + else: + x = F.avg_pool2d(x, 2, stride=2) + else: + if not self.with_conv: + x = downsample_2d(x, self.fir_kernel, factor=2) + else: + x = self.Conv2d_0(x) + + return x + + +class ResnetBlockDDPMpp(nn.Module): + """ResBlock adapted from DDPM.""" + + def __init__( + self, + act, + in_ch, + out_ch=None, + temb_dim=None, + conv_shortcut=False, + dropout=0.1, + skip_rescale=False, + init_scale=0.0, + ): + super().__init__() + out_ch = out_ch if out_ch 
else in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) + self.Conv_0 = conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) + nn.init.zeros_(self.Dense_0.bias) + self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) + if in_ch != out_ch: + if conv_shortcut: + self.Conv_2 = conv3x3(in_ch, out_ch) + else: + self.NIN_0 = NIN(in_ch, out_ch) + + self.skip_rescale = skip_rescale + self.act = act + self.out_ch = out_ch + self.conv_shortcut = conv_shortcut + + def forward(self, x, temb=None): + h = self.act(self.GroupNorm_0(x)) + h = self.Conv_0(h) + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + if x.shape[1] != self.out_ch: + if self.conv_shortcut: + x = self.Conv_2(x) + else: + x = self.NIN_0(x) + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class ResnetBlockBigGANpp(nn.Module): + def __init__( + self, + act, + in_ch, + out_ch=None, + temb_dim=None, + up=False, + down=False, + dropout=0.1, + fir=False, + fir_kernel=(1, 3, 3, 1), + skip_rescale=True, + init_scale=0.0, + ): + super().__init__() + + out_ch = out_ch if out_ch else in_ch + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) + self.up = up + self.down = down + self.fir = fir + self.fir_kernel = fir_kernel + + self.Conv_0 = conv3x3(in_ch, out_ch) + if temb_dim is not None: + self.Dense_0 = nn.Linear(temb_dim, out_ch) + self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) + nn.init.zeros_(self.Dense_0.bias) + + self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) + self.Dropout_0 = nn.Dropout(dropout) + self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) + if in_ch != out_ch or up or down: + self.Conv_2 = conv1x1(in_ch, out_ch) + + self.skip_rescale = skip_rescale + self.act = act + self.in_ch = in_ch + self.out_ch = out_ch + + def forward(self, x, temb=None): + h = self.act(self.GroupNorm_0(x)) + + if self.up: + if self.fir: + h = upsample_2d(h, self.fir_kernel, factor=2) + x = upsample_2d(x, self.fir_kernel, factor=2) + else: + h = naive_upsample_2d(h, factor=2) + x = naive_upsample_2d(x, factor=2) + elif self.down: + if self.fir: + h = downsample_2d(h, self.fir_kernel, factor=2) + x = downsample_2d(x, self.fir_kernel, factor=2) + else: + h = naive_downsample_2d(h, factor=2) + x = naive_downsample_2d(x, factor=2) + + h = self.Conv_0(h) + # Add bias to each feature map conditioned on the time embedding + if temb is not None: + h += self.Dense_0(self.act(temb))[:, :, None, None] + h = self.act(self.GroupNorm_1(h)) + h = self.Dropout_0(h) + h = self.Conv_1(h) + + if self.in_ch != self.out_ch or self.up or self.down: + x = self.Conv_2(x) + + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0) + + +class NCSNpp(ModelMixin, ConfigMixin): + """NCSN++ model""" + + def __init__( + self, + centered=False, + image_size=1024, + num_channels=3, + attention_type="ddpm", + attn_resolutions=(16,), + ch_mult=(1, 2, 4, 8, 16, 32, 32, 32), + conditional=True, + conv_size=3, + dropout=0.0, + embedding_type="fourier", + fir=True, + fir_kernel=(1, 3, 3, 1), + 
fourier_scale=16, + init_scale=0.0, + nf=16, + nonlinearity="swish", + normalization="GroupNorm", + num_res_blocks=1, + progressive="output_skip", + progressive_combine="sum", + progressive_input="input_skip", + resamp_with_conv=True, + resblock_type="biggan", + scale_by_sigma=True, + skip_rescale=True, + continuous=True, + ): + super().__init__() + self.register_to_config( + centered=centered, + image_size=image_size, + num_channels=num_channels, + attention_type=attention_type, + attn_resolutions=attn_resolutions, + ch_mult=ch_mult, + conditional=conditional, + conv_size=conv_size, + dropout=dropout, + embedding_type=embedding_type, + fir=fir, + fir_kernel=fir_kernel, + fourier_scale=fourier_scale, + init_scale=init_scale, + nf=nf, + nonlinearity=nonlinearity, + normalization=normalization, + num_res_blocks=num_res_blocks, + progressive=progressive, + progressive_combine=progressive_combine, + progressive_input=progressive_input, + resamp_with_conv=resamp_with_conv, + resblock_type=resblock_type, + scale_by_sigma=scale_by_sigma, + skip_rescale=skip_rescale, + continuous=continuous, + ) + self.act = act = get_act(nonlinearity) + + self.nf = nf + self.num_res_blocks = num_res_blocks + self.attn_resolutions = attn_resolutions + self.num_resolutions = len(ch_mult) + self.all_resolutions = all_resolutions = [image_size // (2**i) for i in range(self.num_resolutions)] + + self.conditional = conditional + self.skip_rescale = skip_rescale + self.resblock_type = resblock_type + self.progressive = progressive + self.progressive_input = progressive_input + self.embedding_type = embedding_type + assert progressive in ["none", "output_skip", "residual"] + assert progressive_input in ["none", "input_skip", "residual"] + assert embedding_type in ["fourier", "positional"] + combine_method = progressive_combine.lower() + combiner = functools.partial(Combine, method=combine_method) + + modules = [] + # timestep/noise_level embedding; only for continuous training + if embedding_type == "fourier": + # Gaussian Fourier features embeddings. 
+ modules.append(GaussianFourierProjection(embedding_size=nf, scale=fourier_scale)) + embed_dim = 2 * nf + + elif embedding_type == "positional": + embed_dim = nf + + else: + raise ValueError(f"embedding type {embedding_type} unknown.") + + if conditional: + modules.append(nn.Linear(embed_dim, nf * 4)) + modules[-1].weight.data = default_init()(modules[-1].weight.shape) + nn.init.zeros_(modules[-1].bias) + modules.append(nn.Linear(nf * 4, nf * 4)) + modules[-1].weight.data = default_init()(modules[-1].weight.shape) + nn.init.zeros_(modules[-1].bias) + + AttnBlock = functools.partial(AttnBlockpp, init_scale=init_scale, skip_rescale=skip_rescale) + + Up_sample = functools.partial(Upsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) + + if progressive == "output_skip": + self.pyramid_upsample = Up_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False) + elif progressive == "residual": + pyramid_upsample = functools.partial(Up_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True) + + Down_sample = functools.partial(Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) + + if progressive_input == "input_skip": + self.pyramid_downsample = Down_sample(fir=fir, fir_kernel=fir_kernel, with_conv=False) + elif progressive_input == "residual": + pyramid_downsample = functools.partial(Down_sample, fir=fir, fir_kernel=fir_kernel, with_conv=True) + + if resblock_type == "ddpm": + ResnetBlock = functools.partial( + ResnetBlockDDPMpp, + act=act, + dropout=dropout, + init_scale=init_scale, + skip_rescale=skip_rescale, + temb_dim=nf * 4, + ) + + elif resblock_type == "biggan": + ResnetBlock = functools.partial( + ResnetBlockBigGANpp, + act=act, + dropout=dropout, + fir=fir, + fir_kernel=fir_kernel, + init_scale=init_scale, + skip_rescale=skip_rescale, + temb_dim=nf * 4, + ) + + else: + raise ValueError(f"resblock type {resblock_type} unrecognized.") + + # Downsampling block + + channels = num_channels + if progressive_input != "none": + input_pyramid_ch = channels + + modules.append(conv3x3(channels, nf)) + hs_c = [nf] + + in_ch = nf + for i_level in range(self.num_resolutions): + # Residual blocks for this resolution + for i_block in range(num_res_blocks): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) + in_ch = out_ch + + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + hs_c.append(in_ch) + + if i_level != self.num_resolutions - 1: + if resblock_type == "ddpm": + modules.append(Downsample(in_ch=in_ch)) + else: + modules.append(ResnetBlock(down=True, in_ch=in_ch)) + + if progressive_input == "input_skip": + modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) + if combine_method == "cat": + in_ch *= 2 + + elif progressive_input == "residual": + modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)) + input_pyramid_ch = in_ch + + hs_c.append(in_ch) + + in_ch = hs_c[-1] + modules.append(ResnetBlock(in_ch=in_ch)) + modules.append(AttnBlock(channels=in_ch)) + modules.append(ResnetBlock(in_ch=in_ch)) + + pyramid_ch = 0 + # Upsampling block + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(num_res_blocks + 1): + out_ch = nf * ch_mult[i_level] + modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) + in_ch = out_ch + + if all_resolutions[i_level] in attn_resolutions: + modules.append(AttnBlock(channels=in_ch)) + + if progressive != "none": + if i_level == self.num_resolutions - 1: + if progressive == "output_skip": + 
modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) + pyramid_ch = channels + elif progressive == "residual": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, in_ch, bias=True)) + pyramid_ch = in_ch + else: + raise ValueError(f"{progressive} is not a valid name.") + else: + if progressive == "output_skip": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale)) + pyramid_ch = channels + elif progressive == "residual": + modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) + pyramid_ch = in_ch + else: + raise ValueError(f"{progressive} is not a valid name") + + if i_level != 0: + if resblock_type == "ddpm": + modules.append(Upsample(in_ch=in_ch)) + else: + modules.append(ResnetBlock(in_ch=in_ch, up=True)) + + assert not hs_c + + if progressive != "output_skip": + modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)) + modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) + + self.all_modules = nn.ModuleList(modules) + + def forward(self, x, time_cond, sigmas=None): + # timestep/noise_level embedding; only for continuous training + modules = self.all_modules + m_idx = 0 + if self.embedding_type == "fourier": + # Gaussian Fourier features embeddings. + used_sigmas = time_cond + temb = modules[m_idx](torch.log(used_sigmas)) + m_idx += 1 + + elif self.embedding_type == "positional": + # Sinusoidal positional embeddings. + timesteps = time_cond + used_sigmas = sigmas + temb = get_timestep_embedding(timesteps, self.nf) + + else: + raise ValueError(f"embedding type {self.embedding_type} unknown.") + + if self.conditional: + temb = modules[m_idx](temb) + m_idx += 1 + temb = modules[m_idx](self.act(temb)) + m_idx += 1 + else: + temb = None + + if not self.config.centered: + # If input data is in [0, 1] + x = 2 * x - 1.0 + + # Downsampling block + input_pyramid = None + if self.progressive_input != "none": + input_pyramid = x + + hs = [modules[m_idx](x)] + m_idx += 1 + for i_level in range(self.num_resolutions): + # Residual blocks for this resolution + for i_block in range(self.num_res_blocks): + h = modules[m_idx](hs[-1], temb) + m_idx += 1 + if h.shape[-1] in self.attn_resolutions: + h = modules[m_idx](h) + m_idx += 1 + + hs.append(h) + + if i_level != self.num_resolutions - 1: + if self.resblock_type == "ddpm": + h = modules[m_idx](hs[-1]) + m_idx += 1 + else: + h = modules[m_idx](hs[-1], temb) + m_idx += 1 + + if self.progressive_input == "input_skip": + input_pyramid = self.pyramid_downsample(input_pyramid) + h = modules[m_idx](input_pyramid, h) + m_idx += 1 + + elif self.progressive_input == "residual": + input_pyramid = modules[m_idx](input_pyramid) + m_idx += 1 + if self.skip_rescale: + input_pyramid = (input_pyramid + h) / np.sqrt(2.0) + else: + input_pyramid = input_pyramid + h + h = input_pyramid + + hs.append(h) + + h = hs[-1] + h = modules[m_idx](h, temb) + m_idx += 1 + h = modules[m_idx](h) + m_idx += 1 + h = modules[m_idx](h, temb) + m_idx += 1 + + pyramid = None + + # Upsampling block + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) + m_idx += 1 + + if h.shape[-1] in self.attn_resolutions: + h = modules[m_idx](h) + m_idx 
+= 1 + + if self.progressive != "none": + if i_level == self.num_resolutions - 1: + if self.progressive == "output_skip": + pyramid = self.act(modules[m_idx](h)) + m_idx += 1 + pyramid = modules[m_idx](pyramid) + m_idx += 1 + elif self.progressive == "residual": + pyramid = self.act(modules[m_idx](h)) + m_idx += 1 + pyramid = modules[m_idx](pyramid) + m_idx += 1 + else: + raise ValueError(f"{self.progressive} is not a valid name.") + else: + if self.progressive == "output_skip": + pyramid = self.pyramid_upsample(pyramid) + pyramid_h = self.act(modules[m_idx](h)) + m_idx += 1 + pyramid_h = modules[m_idx](pyramid_h) + m_idx += 1 + pyramid = pyramid + pyramid_h + elif self.progressive == "residual": + pyramid = modules[m_idx](pyramid) + m_idx += 1 + if self.skip_rescale: + pyramid = (pyramid + h) / np.sqrt(2.0) + else: + pyramid = pyramid + h + h = pyramid + else: + raise ValueError(f"{self.progressive} is not a valid name") + + if i_level != 0: + if self.resblock_type == "ddpm": + h = modules[m_idx](h) + m_idx += 1 + else: + h = modules[m_idx](h, temb) + m_idx += 1 + + assert not hs + + if self.progressive == "output_skip": + h = pyramid + else: + h = self.act(modules[m_idx](h)) + m_idx += 1 + h = modules[m_idx](h) + m_idx += 1 + + assert m_idx == len(modules) + if self.config.scale_by_sigma: + used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:])))) + h = h / used_sigmas + + return h diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index d8a2644dc9..d73b8d8fb3 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -21,7 +21,6 @@ from typing import Optional, Union from huggingface_hub import snapshot_download from .configuration_utils import ConfigMixin -from .dynamic_modules_utils import get_class_from_dynamic_module from .utils import DIFFUSERS_CACHE, logging @@ -81,16 +80,13 @@ class DiffusionPipeline(ConfigMixin): # set models setattr(self, name, module) - register_dict = {"_module": self.__module__.split(".")[-1]} - self.register_to_config(**register_dict) - def save_pretrained(self, save_directory: Union[str, os.PathLike]): self.save_config(save_directory) model_index_dict = dict(self.config) model_index_dict.pop("_class_name") model_index_dict.pop("_diffusers_version") - model_index_dict.pop("_module") + model_index_dict.pop("_module", None) for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) @@ -139,11 +135,7 @@ class DiffusionPipeline(ConfigMixin): config_dict = cls.get_config_dict(cached_folder) - # 2. Get class name and module candidates to load custom models - module_candidate_name = config_dict["_module"] - module_candidate = module_candidate_name + ".py" - - # 3. Load the pipeline class, if using custom module then load it from the hub + # 2. 
Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it if cls != DiffusionPipeline: pipeline_class = cls @@ -151,11 +143,6 @@ class DiffusionPipeline(ConfigMixin): diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) - # (TODO - we should allow to load custom pipelines - # else we need to load the correct module from the Hub - # module = module_candidate - # pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) - init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} @@ -163,7 +150,7 @@ class DiffusionPipeline(ConfigMixin): # import it here to avoid circular import from diffusers import pipelines - # 4. Load each module in the pipeline + # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): is_pipeline_module = hasattr(pipelines, library_name) # if the model is in a pipeline module, then we load it from the pipeline @@ -171,14 +158,7 @@ class DiffusionPipeline(ConfigMixin): pipeline_module = getattr(pipelines, library_name) class_obj = getattr(pipeline_module, class_name) importable_classes = ALL_IMPORTABLE_CLASSES - class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} - elif library_name == module_candidate_name: - # if the model is not in diffusers or transformers, we need to load it from the hub - # assumes that it's a subclass of ModelMixin - class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder) - # since it's not from a library, we need to check class candidates for all importable classes - importable_classes = ALL_IMPORTABLE_CLASSES - class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} + class_candidates = {c: class_obj for c in importable_classes.keys()} else: # else we just import it from the library. library = importlib.import_module(library_name) diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md index 61e653a80f..c0558d35b9 100644 --- a/src/diffusers/pipelines/README.md +++ b/src/diffusers/pipelines/README.md @@ -15,5 +15,5 @@ TODO(Patrick, Anton, Suraj) - PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py). - Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py). - Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py). -- BDDM for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py). +- BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py). - Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py). 
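The hunks above rename the pipeline classes to carry an explicit `Pipeline` suffix and strip the custom `_module` lookup from `DiffusionPipeline.from_pretrained`, which now resolves the `_class_name` stored in the pipeline config directly from the `diffusers` module. Below is a minimal usage sketch (not part of this patch) of loading one of the renamed pipelines; the `fusing/ddpm-cifar10` checkpoint id is the one used in `tests/test_modeling_utils.py` further down, and the call signature is assumed to match the `DDPMPipeline` defined in this diff.

```python
import torch

from diffusers import DDPMPipeline, DiffusionPipeline

generator = torch.manual_seed(0)

# load via the explicit (renamed) class ...
ddpm = DDPMPipeline.from_pretrained("fusing/ddpm-cifar10")

# ... or via the generic entry point, which now resolves the class name
# recorded in the pipeline config from the diffusers module itself
ddpm_from_hub = DiffusionPipeline.from_pretrained("fusing/ddpm-cifar10")

# generate an image tensor of shape (batch, channels, height, width)
image = ddpm(generator=generator)
```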
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7ba126b03b..5d7b1f14cf 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,14 +1,19 @@ from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available -from .pipeline_bddm import BDDM -from .pipeline_ddim import DDIM -from .pipeline_ddpm import DDPM -from .pipeline_pndm import PNDM +from .pipeline_bddm import BDDMPipeline +from .pipeline_ddim import DDIMPipeline +from .pipeline_ddpm import DDPMPipeline +from .pipeline_pndm import PNDMPipeline +from .pipeline_score_sde_ve import ScoreSdeVePipeline +from .pipeline_score_sde_vp import ScoreSdeVpPipeline + + +# from .pipeline_score_sde import ScoreSdeVePipeline if is_transformers_available(): - from .pipeline_glide import Glide - from .pipeline_latent_diffusion import LatentDiffusion + from .pipeline_glide import GlidePipeline + from .pipeline_latent_diffusion import LatentDiffusionPipeline if is_transformers_available() and is_unidecode_available() and is_inflect_available(): - from .pipeline_grad_tts import GradTTS + from .pipeline_grad_tts import GradTTSPipeline diff --git a/src/diffusers/pipelines/pipeline_bddm.py b/src/diffusers/pipelines/pipeline_bddm.py index 3ca79c3dee..8b24cb9ceb 100644 --- a/src/diffusers/pipelines/pipeline_bddm.py +++ b/src/diffusers/pipelines/pipeline_bddm.py @@ -271,7 +271,7 @@ class DiffWave(ModelMixin, ConfigMixin): return self.final_conv(x) -class BDDM(DiffusionPipeline): +class BDDMPipeline(DiffusionPipeline): def __init__(self, diffwave, noise_scheduler): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") diff --git a/src/diffusers/pipelines/pipeline_ddim.py b/src/diffusers/pipelines/pipeline_ddim.py index 272d3edb6b..8da24dbf8f 100644 --- a/src/diffusers/pipelines/pipeline_ddim.py +++ b/src/diffusers/pipelines/pipeline_ddim.py @@ -21,7 +21,7 @@ import tqdm from ..pipeline_utils import DiffusionPipeline -class DDIM(DiffusionPipeline): +class DDIMPipeline(DiffusionPipeline): def __init__(self, unet, noise_scheduler): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") diff --git a/src/diffusers/pipelines/pipeline_ddpm.py b/src/diffusers/pipelines/pipeline_ddpm.py index ebcce77337..9cf83bfb75 100644 --- a/src/diffusers/pipelines/pipeline_ddpm.py +++ b/src/diffusers/pipelines/pipeline_ddpm.py @@ -21,7 +21,7 @@ import tqdm from ..pipeline_utils import DiffusionPipeline -class DDPM(DiffusionPipeline): +class DDPMPipeline(DiffusionPipeline): def __init__(self, unet, noise_scheduler): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index 2a6c073ec2..8680b7542a 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -695,7 +695,23 @@ class CLIPTextModel(CLIPPreTrainedModel): ##################### -class Glide(DiffusionPipeline): +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
+ """ + res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + torch.zeros(broadcast_shape, device=timesteps.device) + + +class GlidePipeline(DiffusionPipeline): def __init__( self, text_unet: GlideTextToImageUNetModel, diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 4201124923..743104e658 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -420,7 +420,7 @@ class TextEncoder(ModelMixin, ConfigMixin): return mu, logw, x_mask -class GradTTS(DiffusionPipeline): +class GradTTSPipeline(DiffusionPipeline): def __init__(self, unet, text_encoder, noise_scheduler, tokenizer): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") @@ -430,7 +430,14 @@ class GradTTS(DiffusionPipeline): @torch.no_grad() def __call__( - self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None + self, + text, + num_inference_steps=50, + temperature=1.3, + length_scale=0.91, + speaker_id=15, + torch_device=None, + generator=None, ): if torch_device is None: torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -464,17 +471,19 @@ class GradTTS(DiffusionPipeline): mu_y = mu_y.transpose(1, 2) # Sample latent representation from terminal distribution N(mu_y, I) - z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature + z = mu_y + torch.randn(mu_y.shape, generator=generator).to(mu_y.device) xt = z * y_mask h = 1.0 / num_inference_steps + # (Patrick: TODO) for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps): + t_new = num_inference_steps - t - 1 t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) - time = t.unsqueeze(-1).unsqueeze(-1) residual = self.unet(xt, t, mu_y, y_mask, speaker_id) - xt = self.noise_scheduler.step(xt, residual, mu_y, h, time) + scheduler_residual = residual - mu_y + xt + xt = self.noise_scheduler.step(scheduler_residual, xt, t_new, num_inference_steps) xt = xt * y_mask return xt[:, :, :y_max_length] diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py index cd7f653bf4..ffc8ae670c 100644 --- a/src/diffusers/pipelines/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py @@ -1,17 +1,557 @@ -# pytorch_diffusion + derived encoder decoder import math +from typing import Optional, Tuple, Union import numpy as np import torch import torch.nn as nn +import torch.utils.checkpoint import tqdm + +try: + from transformers.activations import ACT2FN + from transformers.configuration_utils import PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput + from transformers.modeling_utils import PreTrainedModel + from transformers.utils import logging +except ImportError: + raise ImportError("Please install the transformers.") + from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from ..pipeline_utils import DiffusionPipeline +################################################################################ +# Code for the text transformer model +################################################################################ +""" PyTorch LDMBERT model.""" + + +logger = logging.get_logger(__name__) + +LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "ldm-bert", + # See all LDMBert models at 
https://huggingface.co/models?filter=ldmbert +] + + +LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json", +} + + +""" LDMBERT model configuration""" + + +class LDMBertConfig(PretrainedConfig): + model_type = "ldmbert" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=30522, + max_position_embeddings=77, + encoder_layers=32, + encoder_ffn_dim=5120, + encoder_attention_heads=8, + head_dim=64, + encoder_layerdrop=0.0, + activation_function="gelu", + d_model=1280, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + classifier_dropout=0.0, + scale_embedding=False, + use_cache=True, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.head_dim = head_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.encoder_layerdrop = encoder_layerdrop + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + self.num_hidden_layers = encoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + + super().__init__(pad_token_id=pad_token_id, **kwargs) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert +class LDMBertAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = False, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = head_dim + self.inner_dim = head_dim * num_heads + + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias) + self.out_proj = nn.Linear(self.inner_dim, embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class LDMBertEncoderLayer(nn.Module): + def __init__(self, config: LDMBertConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = LDMBertAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.head_dim, + dropout=config.attention_dropout, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + layer_head_mask: torch.FloatTensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert +class LDMBertPreTrainedModel(PreTrainedModel): + config_class = LDMBertConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if 
module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LDMBertEncoder,)): + module.gradient_checkpointing = value + + @property + def dummy_inputs(self): + pad_token = self.config.pad_token_id + input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) + dummy_inputs = { + "attention_mask": input_ids.ne(pad_token), + "input_ids": input_ids, + } + return dummy_inputs + + +class LDMBertEncoder(LDMBertPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`LDMBertEncoderLayer`]. + + Args: + config: LDMBertConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: LDMBertConfig): + super().__init__(config) + + self.dropout = config.dropout + + embed_dim = config.d_model + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_position_embeddings + + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim) + self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim) + self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + seq_len = input_shape[1] + if position_ids is None: + position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1)) + embed_pos = self.embed_positions(position_ids) + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class LDMBertModel(LDMBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = LDMBertEncoder(config) + self.to_logits = nn.Linear(config.hidden_size, config.vocab_size) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + return sequence_output + + def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -860,7 +1400,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): return dec, posterior -class LatentDiffusion(DiffusionPipeline): +class LatentDiffusionPipeline(DiffusionPipeline): def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") @@ -891,11 +1431,11 @@ class LatentDiffusion(DiffusionPipeline): uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to( torch_device ) - uncond_embeddings = self.bert(uncond_input.input_ids)[0] + uncond_embeddings = self.bert(uncond_input.input_ids) # get text embedding text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device) - text_embedding = self.bert(text_input.input_ids)[0] + text_embedding = self.bert(text_input.input_ids) num_trained_timesteps = self.noise_scheduler.config.timesteps inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) diff --git a/src/diffusers/pipelines/pipeline_pndm.py b/src/diffusers/pipelines/pipeline_pndm.py index a19f933ed1..5fd8a98483 100644 --- a/src/diffusers/pipelines/pipeline_pndm.py +++ b/src/diffusers/pipelines/pipeline_pndm.py @@ -21,7 +21,7 @@ import tqdm from ..pipeline_utils import DiffusionPipeline -class PNDM(DiffusionPipeline): +class PNDMPipeline(DiffusionPipeline): def __init__(self, unet, noise_scheduler): super().__init__() noise_scheduler = 
noise_scheduler.set_format("pt") diff --git a/src/diffusers/pipelines/pipeline_score_sde_ve.py b/src/diffusers/pipelines/pipeline_score_sde_ve.py new file mode 100644 index 0000000000..1dfd304d83 --- /dev/null +++ b/src/diffusers/pipelines/pipeline_score_sde_ve.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +import torch + +from diffusers import DiffusionPipeline + + +# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names +class ScoreSdeVePipeline(DiffusionPipeline): + def __init__(self, model, scheduler): + super().__init__() + self.register_modules(model=model, scheduler=scheduler) + + def __call__(self, num_inference_steps=2000, generator=None): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + img_size = self.model.config.image_size + channels = self.model.config.num_channels + shape = (1, channels, img_size, img_size) + + model = self.model.to(device) + + # TODO(Patrick) move to scheduler config + n_steps = 1 + + x = torch.randn(*shape) * self.scheduler.config.sigma_max + x = x.to(device) + + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_sigmas(num_inference_steps) + + for i, t in enumerate(self.scheduler.timesteps): + sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) + + for _ in range(n_steps): + with torch.no_grad(): + result = self.model(x, sigma_t) + x = self.scheduler.step_correct(result, x) + + with torch.no_grad(): + result = model(x, sigma_t) + + x, x_mean = self.scheduler.step_pred(result, x, t) + + return x_mean diff --git a/src/diffusers/pipelines/pipeline_score_sde_vp.py b/src/diffusers/pipelines/pipeline_score_sde_vp.py new file mode 100644 index 0000000000..29551d9a6e --- /dev/null +++ b/src/diffusers/pipelines/pipeline_score_sde_vp.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +import torch + +from diffusers import DiffusionPipeline + + +# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names +class ScoreSdeVpPipeline(DiffusionPipeline): + def __init__(self, model, scheduler): + super().__init__() + self.register_modules(model=model, scheduler=scheduler) + + def __call__(self, num_inference_steps=1000, generator=None): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + img_size = self.model.config.image_size + channels = self.model.config.num_channels + shape = (1, channels, img_size, img_size) + + model = self.model.to(device) + + x = torch.randn(*shape).to(device) + + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.scheduler.timesteps: + t = t * torch.ones(shape[0], device=device) + scaled_t = t * (num_inference_steps - 1) + + with torch.no_grad(): + result = model(x, scaled_t) + + x, x_mean = self.scheduler.step_pred(result, x, t) + + x_mean = (x_mean + 1.0) / 2.0 + + return x_mean diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index b2d533d380..ad66fe5991 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -20,4 +20,6 @@ from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler +from .scheduling_sde_ve import ScoreSdeVeScheduler +from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_utils import SchedulerMixin diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 331fad0f1e..5dea0b22b3 100644 --- 
a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -92,9 +92,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): alpha_prod_t = self.alphas_cumprod[t] alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one - # For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample - # x_{t-1} ~ N(pred_prev_sample, variance) == add variane to pred_sample + # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] if variance_type is None: diff --git a/src/diffusers/schedulers/scheduling_grad_tts.py b/src/diffusers/schedulers/scheduling_grad_tts.py index 94b5f2ac55..4dc6638de3 100644 --- a/src/diffusers/schedulers/scheduling_grad_tts.py +++ b/src/diffusers/schedulers/scheduling_grad_tts.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin @@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin class GradTTSScheduler(SchedulerMixin, ConfigMixin): def __init__( self, - timesteps=1000, - beta_start=0.0001, - beta_end=0.02, + beta_start=0.05, + beta_end=20, tensor_format="np", ): super().__init__() self.register_to_config( - timesteps=timesteps, beta_start=beta_start, beta_end=beta_end, ) self.set_format(tensor_format=tensor_format) + self.betas = None - def sample_noise(self, timestep): - noise = self.beta_start + (self.beta_end - self.beta_start) * timestep - return noise + def get_timesteps(self, num_inference_steps): + return np.array([(t + 0.5) / num_inference_steps for t in range(num_inference_steps)]) - def step(self, xt, residual, mu, h, timestep): - noise_t = self.sample_noise(timestep) - dxt = 0.5 * (mu - xt - residual) - dxt = dxt * noise_t * h - xt = xt - dxt - return xt + def set_betas(self, num_inference_steps): + timesteps = self.get_timesteps(num_inference_steps) + self.betas = np.array([self.beta_start + (self.beta_end - self.beta_start) * t for t in timesteps]) - def __len__(self): - return len(self.config.timesteps) + def step(self, residual, sample, t, num_inference_steps): + # This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix) + if self.betas is None: + self.set_betas(num_inference_steps) + + beta_t = self.betas[t] + beta_t_deriv = beta_t / num_inference_steps + + sample_deriv = residual * beta_t_deriv / 2 + + sample = sample + sample_deriv + return sample diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py new file mode 100644 index 0000000000..79936105b9 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -0,0 +1,84 @@ +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin +from .scheduling_utils import SchedulerMixin + + +class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): + def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"): + super().__init__() + self.register_to_config( + snr=snr, + sigma_min=sigma_min, + sigma_max=sigma_max, + sampling_eps=sampling_eps, + ) + + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + + def set_sigmas(self, num_inference_steps): + if self.timesteps is None: + self.set_timesteps(num_inference_steps) + + self.discrete_sigmas = torch.exp( + torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) + ) + self.sigmas = torch.tensor( + [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps] + ) + + def step_pred(self, result, x, t): + # TODO(Patrick) better comments + non-PyTorch + t = t * torch.ones(x.shape[0], device=x.device) + timestep = (t * (len(self.timesteps) - 1)).long() + + sigma = self.discrete_sigmas.to(t.device)[timestep] + adjacent_sigma = torch.where( + timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device) + ) + f = torch.zeros_like(x) + G = torch.sqrt(sigma**2 - adjacent_sigma**2) + + f = f - G[:, None, None, None] ** 2 * result + + z = torch.randn_like(x) + x_mean = x - f + x = x_mean + G[:, None, None, None] * z + return x, x_mean + + def step_correct(self, result, x): + # TODO(Patrick) better comments + non-PyTorch + noise = torch.randn_like(x) + grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean() + noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean() + step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2 + step_size = step_size * torch.ones(x.shape[0], device=x.device) + x_mean = x + step_size[:, None, None, None] * result + + x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise + + return x diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py new file mode 100644 index 0000000000..dda32a2742 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -0,0 +1,64 @@ +# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch + +# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin +from .scheduling_utils import SchedulerMixin + + +class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): + def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"): + super().__init__() + self.register_to_config( + beta_min=beta_min, + beta_max=beta_max, + sampling_eps=sampling_eps, + ) + + self.sigmas = None + self.discrete_sigmas = None + self.timesteps = None + + def set_timesteps(self, num_inference_steps): + self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + + def step_pred(self, result, x, t): + # TODO(Patrick) better comments + non-PyTorch + # postprocess model result + log_mean_coeff = ( + -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min + ) + std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff)) + result = -result / std[:, None, None, None] + + # compute + dt = -1.0 / len(self.timesteps) + + beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min) + drift = -0.5 * beta_t[:, None, None, None] * x + diffusion = torch.sqrt(beta_t) + drift = drift - diffusion[:, None, None, None] ** 2 * result + x_mean = x + drift * dt + + # add noise + z = torch.randn_like(x) + x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z + + return x, x_mean diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py new file mode 100755 index 0000000000..42a4261081 --- /dev/null +++ b/tests/test_layers_utils.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import tempfile +import unittest + +import numpy as np +import torch + +from diffusers.models.embeddings import get_timestep_embedding +from diffusers.testing_utils import floats_tensor, slow, torch_device + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class EmbeddingsTests(unittest.TestCase): + def test_timestep_embeddings(self): + embedding_dim = 256 + timesteps = torch.arange(16) + + t1 = get_timestep_embedding(timesteps, embedding_dim) + + # first vector should always be composed only of 0's and 1's + assert (t1[0, : embedding_dim // 2] - 0).abs().sum() < 1e-5 + assert (t1[0, embedding_dim // 2 :] - 1).abs().sum() < 1e-5 + + # last element of each vector should be one + assert (t1[:, -1] - 1).abs().sum() < 1e-5 + + # For large embeddings (e.g. 
128) the frequency of every vector is higher + # than the previous one which means that the gradients of later vectors are + # ALWAYS higher than the previous ones + grad_mean = np.abs(np.gradient(t1, axis=-1)).mean(axis=1) + + prev_grad = 0.0 + for grad in grad_mean: + assert grad > prev_grad + prev_grad = grad + + def test_timestep_defaults(self): + embedding_dim = 16 + timesteps = torch.arange(10) + + t1 = get_timestep_embedding(timesteps, embedding_dim) + t2 = get_timestep_embedding( + timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000 + ) + + assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3) + + def test_timestep_flip_sin_cos(self): + embedding_dim = 16 + timesteps = torch.arange(10) + + t1 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True) + t1 = torch.cat([t1[:, embedding_dim // 2 :], t1[:, : embedding_dim // 2]], dim=-1) + + t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False) + + assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3) + + def test_timestep_downscale_freq_shift(self): + embedding_dim = 16 + timesteps = torch.arange(10) + + t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0) + t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1) + + # get cosine half (vectors that are wrapped into cosine) + cosine_half = (t1 - t2)[:, embedding_dim // 2 :] + + # cosine needs to be negative + assert (np.abs((cosine_half <= 0).numpy()) - 1).sum() < 1e-5 + + def test_sinoid_embeddings_hardcoded(self): + embedding_dim = 64 + timesteps = torch.arange(128) + + # standard unet, score_vde + t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1, flip_sin_to_cos=False) + # glide, ldm + t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0, flip_sin_to_cos=True) + # grad-tts + t3 = get_timestep_embedding(timesteps, embedding_dim, scale=1000) + + assert torch.allclose( + t1[23:26, 47:50].flatten().cpu(), + torch.tensor([0.9646, 0.9804, 0.9892, 0.9615, 0.9787, 0.9882, 0.9582, 0.9769, 0.9872]), + 1e-3, + ) + assert torch.allclose( + t2[23:26, 47:50].flatten().cpu(), + torch.tensor([0.3019, 0.2280, 0.1716, 0.3146, 0.2377, 0.1790, 0.3272, 0.2474, 0.1864]), + 1e-3, + ) + assert torch.allclose( + t3[23:26, 47:50].flatten().cpu(), + torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]), + 1e-3, + ) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 0c9b106c1d..697a377f8c 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -22,18 +22,24 @@ import numpy as np import torch from diffusers import ( - BDDM, - DDIM, - DDPM, - PNDM, + BDDMPipeline, + DDIMPipeline, DDIMScheduler, + DDPMPipeline, DDPMScheduler, - Glide, + GlidePipeline, GlideSuperResUNetModel, GlideTextToImageUNetModel, - GradTTS, - LatentDiffusion, + GradTTSPipeline, + GradTTSScheduler, + LatentDiffusionPipeline, + NCSNpp, + PNDMPipeline, PNDMScheduler, + ScoreSdeVePipeline, + ScoreSdeVeScheduler, + ScoreSdeVpPipeline, + ScoreSdeVpScheduler, UNetGradTTSModel, UNetLDMModel, UNetModel, @@ -107,7 +113,7 @@ class ModelTesterMixin: new_image = new_model(**inputs_dict) max_diff = (image - new_image).abs().sum().item() - self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes") + self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") def test_determinism(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -425,11 
+431,12 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device) + model.to(torch_device) with torch.no_grad(): output = model(noise, time_step, emb) output, _ = torch.split(output, 3, dim=1) - output_slice = output[0, -1, -3:, -3:].flatten() + output_slice = output[0, -1, -3:, -3:].cpu().flatten() # fmt: off expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845]) # fmt: on @@ -583,11 +590,11 @@ class PipelineTesterMixin(unittest.TestCase): model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) schedular = DDPMScheduler(timesteps=10) - ddpm = DDPM(model, schedular) + ddpm = DDPMPipeline(model, schedular) with tempfile.TemporaryDirectory() as tmpdirname: ddpm.save_pretrained(tmpdirname) - new_ddpm = DDPM.from_pretrained(tmpdirname) + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) generator = torch.manual_seed(0) @@ -601,7 +608,7 @@ class PipelineTesterMixin(unittest.TestCase): def test_from_pretrained_hub(self): model_path = "fusing/ddpm-cifar10" - ddpm = DDPM.from_pretrained(model_path) + ddpm = DDPMPipeline.from_pretrained(model_path) ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path) ddpm.noise_scheduler.num_timesteps = 10 @@ -624,7 +631,7 @@ class PipelineTesterMixin(unittest.TestCase): noise_scheduler = DDPMScheduler.from_config(model_id) noise_scheduler = noise_scheduler.set_format("pt") - ddpm = DDPM(unet=unet, noise_scheduler=noise_scheduler) + ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler) image = ddpm(generator=generator) image_slice = image[0, -1, -3:, -3:].cpu() @@ -641,7 +648,7 @@ class PipelineTesterMixin(unittest.TestCase): unet = UNetModel.from_pretrained(model_id) noise_scheduler = DDIMScheduler(tensor_format="pt") - ddim = DDIM(unet=unet, noise_scheduler=noise_scheduler) + ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler) image = ddim(generator=generator, eta=0.0) image_slice = image[0, -1, -3:, -3:].cpu() @@ -660,7 +667,7 @@ class PipelineTesterMixin(unittest.TestCase): unet = UNetModel.from_pretrained(model_id) noise_scheduler = PNDMScheduler(tensor_format="pt") - pndm = PNDM(unet=unet, noise_scheduler=noise_scheduler) + pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler) image = pndm(generator=generator) image_slice = image[0, -1, -3:, -3:].cpu() @@ -672,9 +679,10 @@ class PipelineTesterMixin(unittest.TestCase): assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 @slow + @unittest.skip("Skipping for now as it takes too long") def test_ldm_text2img(self): model_id = "fusing/latent-diffusion-text2im-large" - ldm = LatentDiffusion.from_pretrained(model_id) + ldm = LatentDiffusionPipeline.from_pretrained(model_id) prompt = "A painting of a squirrel eating a burger" generator = torch.manual_seed(0) @@ -686,10 +694,25 @@ class PipelineTesterMixin(unittest.TestCase): expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + @slow + def test_ldm_text2img_fast(self): + model_id = "fusing/latent-diffusion-text2im-large" + ldm = LatentDiffusionPipeline.from_pretrained(model_id) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.manual_seed(0) + image = ldm([prompt], generator=generator, 
num_inference_steps=1) + + image_slice = image[0, -1, -3:, -3:].cpu() + + assert image.shape == (1, 3, 256, 256) + expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344]) + assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + @slow def test_glide_text2img(self): model_id = "fusing/glide-base" - glide = Glide.from_pretrained(model_id) + glide = GlidePipeline.from_pretrained(model_id) prompt = "a pencil sketch of a corgi" generator = torch.manual_seed(0) @@ -704,22 +727,61 @@ class PipelineTesterMixin(unittest.TestCase): @slow def test_grad_tts(self): model_id = "fusing/grad-tts-libri-tts" - grad_tts = GradTTS.from_pretrained(model_id) + grad_tts = GradTTSPipeline.from_pretrained(model_id) + noise_scheduler = GradTTSScheduler() + grad_tts.noise_scheduler = noise_scheduler text = "Hello world, I missed you so much." + generator = torch.manual_seed(0) # generate mel spectograms using text - mel_spec = grad_tts(text) + mel_spec = grad_tts(text, generator=generator) - assert mel_spec.shape == (1, 256, 256, 3) - expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784]) - assert (mel_spec.flatten() - expected_slice).abs().max() < 1e-2 + assert mel_spec.shape == (1, 80, 143) + expected_slice = torch.tensor( + [-6.7584, -6.8347, -6.3293, -6.6437, -6.7233, -6.4684, -6.1187, -6.3172, -6.6890] + ) + assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2 + + @slow + def test_score_sde_ve_pipeline(self): + torch.manual_seed(0) + + model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp") + scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp") + + sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) + + image = sde_ve(num_inference_steps=2) + + expected_image_sum = 3382810112.0 + expected_image_mean = 1075.366455078125 + + assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 + assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 + + @slow + def test_score_sde_vp_pipeline(self): + + model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp") + scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp") + + sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler) + + torch.manual_seed(0) + image = sde_vp(num_inference_steps=10) + + expected_image_sum = 4183.2012 + expected_image_mean = 1.3617 + + assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2 + assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4 def test_module_from_pipeline(self): model = DiffWave(num_res_layers=4) noise_scheduler = DDPMScheduler(timesteps=12) - bddm = BDDM(model, noise_scheduler) + bddm = BDDMPipeline(model, noise_scheduler) # check if the library name for the diffwave moduel is set to pipeline module self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm") @@ -727,6 +789,6 @@ class PipelineTesterMixin(unittest.TestCase): # check if we can save and load the pipeline with tempfile.TemporaryDirectory() as tmpdirname: bddm.save_pretrained(tmpdirname) - _ = BDDM.from_pretrained(tmpdirname) + _ = BDDMPipeline.from_pretrained(tmpdirname) # check if the same works using the DifusionPipeline class _ = DiffusionPipeline.from_pretrained(tmpdirname)
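One pattern worth calling out from the NCSNpp port earlier in this diff: every layer is appended to a single flat `nn.ModuleList` (`self.all_modules`) in construction order, and `forward` walks that list with a manual index `m_idx`, asserting at the end that every registered module was consumed. The toy model below is a hypothetical, minimal sketch of that construction/consumption pattern; the layer types and sizes are made up for illustration and are not taken from the diff.

```python
# Toy illustration of the flat-ModuleList pattern used by the NCSNpp port:
# modules are appended in exactly the order the forward pass will consume
# them, and a running index m_idx replaces named submodule attributes.
import torch
import torch.nn as nn


class TinyFlatNet(nn.Module):
    def __init__(self, in_ch=3, nf=8, num_blocks=2):
        super().__init__()
        modules = [nn.Conv2d(in_ch, nf, kernel_size=3, padding=1)]        # stem
        for _ in range(num_blocks):
            modules.append(nn.Conv2d(nf, nf, kernel_size=3, padding=1))   # body blocks
        modules.append(nn.Conv2d(nf, in_ch, kernel_size=3, padding=1))    # output projection
        self.all_modules = nn.ModuleList(modules)
        self.num_blocks = num_blocks

    def forward(self, x):
        modules = self.all_modules
        m_idx = 0
        h = modules[m_idx](x)                     # stem
        m_idx += 1
        for _ in range(self.num_blocks):
            h = torch.relu(modules[m_idx](h))     # body blocks
            m_idx += 1
        h = modules[m_idx](h)                     # output projection
        m_idx += 1
        assert m_idx == len(modules)              # same invariant NCSNpp.forward checks
        return h


x = torch.randn(1, 3, 16, 16)
print(TinyFlatNet()(x).shape)  # torch.Size([1, 3, 16, 16])
```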