mirror of https://github.com/huggingface/diffusers.git

Merge remote-tracking branch 'origin/main'

# Conflicts:
#	tests/test_modeling_utils.py

README.md (61 lines changed)
@@ -226,6 +226,56 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```

#### **Example 1024x1024 image generation with SDE VE**

See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VE.

```python
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import numpy as np

torch.manual_seed(32)

score_sde_sv = DiffusionPipeline.from_pretrained("fusing/ffhq_ncsnpp")

# Note: this might take up to 3 minutes on a GPU
image = score_sde_sv(num_inference_steps=2000)

image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])

# save image
image_pil.save("test.png")
```

#### **Example 32x32 image generation with SDE VP**

See [paper](https://arxiv.org/abs/2011.13456) for more information on SDE VP.

```python
from diffusers import DiffusionPipeline
import torch
import PIL.Image
import numpy as np

torch.manual_seed(32)

score_sde_sv = DiffusionPipeline.from_pretrained("fusing/cifar10-ddpmpp-deep-vp")

# Note: this might take up to 3 minutes on a GPU
image = score_sde_sv(num_inference_steps=1000)

image = image.permute(0, 2, 3, 1).cpu().numpy()
image = np.clip(image * 255, 0, 255).astype(np.uint8)
image_pil = PIL.Image.fromarray(image[0])

# save image
image_pil.save("test.png")
```


#### **Text to Image generation with Latent Diffusion**

_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
@@ -249,24 +299,24 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```

#### **Text to speech with GradTTS and BDDM**
#### **Text to speech with GradTTS and BDDMPipeline**

```python
import torch
from diffusers import BDDM, GradTTS
from diffusers import BDDMPipeline, GradTTSPipeline

torch_device = "cuda"

# load grad tts and bddm pipelines
grad_tts = GradTTS.from_pretrained("fusing/grad-tts-libri-tts")
bddm = BDDM.from_pretrained("fusing/diffwave-vocoder-ljspeech")
grad_tts = GradTTSPipeline.from_pretrained("fusing/grad-tts-libri-tts")
bddm = BDDMPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech")

text = "Hello world, I missed you so much."

# generate mel spectrograms using text
mel_spec = grad_tts(text, torch_device=torch_device)

# generate the speech by passing mel spectrograms to BDDM pipeline
# generate the speech by passing mel spectrograms to the BDDMPipeline
generator = torch.manual_seed(42)
audio = bddm(mel_spec, generator, torch_device=torch_device)

@@ -288,3 +338,4 @@ wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
- [ ] Add more vision models
- [ ] Add more speech models
- [ ] Add RL model
- [ ] Add FID and KID metrics
run.py (new executable file, 153 lines)
@@ -0,0 +1,153 @@
#!/usr/bin/env python3
import numpy as np
import PIL.Image
import torch

# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs


device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

torch.backends.cuda.matmul.allow_tf32 = False
torch.manual_seed(0)


class NewReverseDiffusionPredictor:
    def __init__(self, score_fn, probability_flow=False, sigma_min=0.0, sigma_max=0.0, N=0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.N = N
        self.discrete_sigmas = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N))

        self.probability_flow = probability_flow
        self.score_fn = score_fn

    def discretize(self, x, t):
        timestep = (t * (self.N - 1)).long()
        sigma = self.discrete_sigmas.to(t.device)[timestep]
        adjacent_sigma = torch.where(
            timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(t.device)
        )
        f = torch.zeros_like(x)
        G = torch.sqrt(sigma**2 - adjacent_sigma**2)

        labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        result = self.score_fn(x, labels)

        rev_f = f - G[:, None, None, None] ** 2 * result * (0.5 if self.probability_flow else 1.0)
        rev_G = torch.zeros_like(G) if self.probability_flow else G
        return rev_f, rev_G

    def update_fn(self, x, t):
        f, G = self.discretize(x, t)
        z = torch.randn_like(x)
        x_mean = x - f
        x = x_mean + G[:, None, None, None] * z
        return x, x_mean


class NewLangevinCorrector:
    def __init__(self, score_fn, snr, n_steps, sigma_min=0.0, sigma_max=0.0):
        super().__init__()
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def update_fn(self, x, t):
        score_fn = self.score_fn
        n_steps = self.n_steps
        target_snr = self.snr
        # if isinstance(sde, VPSDE) or isinstance(sde, subVPSDE):
        #     timestep = (t * (sde.N - 1) / sde.T).long()
        #     alpha = sde.alphas.to(t.device)[timestep]
        # else:
        alpha = torch.ones_like(t)

        for i in range(n_steps):
            labels = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
            grad = score_fn(x, labels)
            noise = torch.randn_like(x)
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
            noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
            step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
            x_mean = x + step_size[:, None, None, None] * grad
            x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise

        return x, x_mean


def save_image(x):
    image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
    image_pil = PIL.Image.fromarray(image_processed[0])
    image_pil.save("../images/hey.png")


# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note: usually we need to restore EMA weights etc.
# An EMA-restored checkpoint is used below.

N = 2
sigma_min = 0.01
sigma_max = 1348
sampling_eps = 1e-5
batch_size = 1
centered = False

from diffusers import NCSNpp

model = NCSNpp.from_pretrained("/home/patrick/ffhq_ncsnpp").to(device)
model = torch.nn.DataParallel(model)

img_size = model.module.config.image_size
channels = model.module.config.num_channels
shape = (batch_size, channels, img_size, img_size)
probability_flow = False
snr = 0.15
n_steps = 1


new_corrector = NewLangevinCorrector(score_fn=model, snr=snr, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
new_predictor = NewReverseDiffusionPredictor(score_fn=model, sigma_min=sigma_min, sigma_max=sigma_max, N=N)

with torch.no_grad():
    # Initial sample
    x = torch.randn(*shape) * sigma_max
    x = x.to(device)
    timesteps = torch.linspace(1, sampling_eps, N, device=device)

    for i in range(N):
        t = timesteps[i]
        vec_t = torch.ones(shape[0], device=t.device) * t
        x, x_mean = new_corrector.update_fn(x, vec_t)
        x, x_mean = new_predictor.update_fn(x, vec_t)

    x = x_mean
    if centered:
        x = (x + 1.0) / 2.0


# save_image(x)

# for 5 steps, cifar10
x_sum = 106071.9922
x_mean = 34.52864456176758

# for 1000 steps, cifar10
x_sum = 461.9700
x_mean = 0.1504

# for 2 steps at 1024x1024
x_sum = 3382810112.0
x_mean = 1075.366455078125


def check_x_sum_x_mean(x, x_sum, x_mean):
    assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
    assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"


check_x_sum_x_mean(x, x_sum, x_mean)
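The script above is essentially the predictor–corrector (PC) sampler of Song et al. (https://arxiv.org/abs/2011.13456), hard-coded to one checkpoint and guarded by the regression values at the end. As a reading aid, here is the same loop factored into a function over the two classes defined above; this is an illustrative sketch, not part of the committed script.

```python
# Sketch only: the PC sampling pattern from run.py as a reusable function.
# `predictor` and `corrector` are instances of the classes defined above.
def pc_sample(predictor, corrector, shape, sigma_max, num_steps, eps=1e-5, device="cpu"):
    x = torch.randn(*shape, device=device) * sigma_max          # sample from the VE prior
    timesteps = torch.linspace(1.0, eps, num_steps, device=device)
    for t in timesteps:
        vec_t = torch.ones(shape[0], device=device) * t
        x, x_mean = corrector.update_fn(x, vec_t)               # Langevin correction step
        x, x_mean = predictor.update_fn(x, vec_t)               # reverse-diffusion predictor step
    return x_mean                                               # final denoised mean, as in run.py
```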
@@ -7,23 +7,29 @@ from .utils import is_inflect_available, is_transformers_available, is_unidecode
__version__ = "0.0.4"

from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_ldm import UNetLDMModel
from .models.unet_rl import TemporalUNet
from .models import NCSNpp, TemporalUNet, UNetLDMModel, UNetModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import BDDM, DDIM, DDPM, PNDM
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
from .pipelines import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline, ScoreSdeVePipeline, ScoreSdeVpPipeline
from .schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    GradTTSScheduler,
    PNDMScheduler,
    SchedulerMixin,
    ScoreSdeVeScheduler,
    ScoreSdeVpScheduler,
)


if is_transformers_available():
    from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
    from .models.unet_grad_tts import UNetGradTTSModel
    from .pipelines import Glide, LatentDiffusion
    from .pipelines import GlidePipeline, LatentDiffusionPipeline
else:
    from .utils.dummy_transformers_objects import *


if is_transformers_available() and is_inflect_available() and is_unidecode_available():
    from .pipelines import GradTTS
    from .pipelines import GradTTSPipeline
else:
    from .utils.dummy_transformers_and_inflect_and_unidecode_objects import *
@@ -21,3 +21,4 @@ from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, Glide
from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet
from .unet_sde_score_estimation import NCSNpp
src/diffusers/models/attention2d.py (new file, 0 lines)
src/diffusers/models/embeddings.py (new file, 85 lines)
@@ -0,0 +1,85 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import torch
from torch import nn


def get_timestep_embedding(
    timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
):
    """
    Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic
    Models.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2

    emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
    emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
    emb = torch.exp(emb * emb_coeff)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
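A minimal usage sketch for `get_timestep_embedding` (not part of the file): it produces an `[N x dim]` tensor, and `flip_sin_to_cos` swaps the sine and cosine halves.

```python
import torch

timesteps = torch.arange(4)  # one timestep index per batch element
emb = get_timestep_embedding(timesteps, embedding_dim=128)
assert emb.shape == (4, 128)  # [N x dim], sine half first, cosine half second

flipped = get_timestep_embedding(timesteps, embedding_dim=128, flip_sin_to_cos=True)
assert torch.allclose(flipped[:, :64], emb[:, 64:])  # cosine half moved to the front
assert torch.allclose(flipped[:, 64:], emb[:, :64])
```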
# unet_sde_score_estimation.py
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
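A usage sketch for `GaussianFourierProjection` (the `scale` value here is arbitrary, for illustration only):

```python
proj = GaussianFourierProjection(embedding_size=256, scale=16.0)
noise_levels = torch.rand(8)          # one continuous noise level per sample
fourier_emb = proj(noise_levels)
assert fourier_emb.shape == (8, 512)  # sin and cos halves concatenated
# self.W is registered with requires_grad=False, so the random
# frequencies stay fixed during training.
```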
# unet_rl.py - TODO(need test)
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
src/diffusers/models/resnet.py (new file, 278 lines)
@@ -0,0 +1,278 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def conv_transpose_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D transposed convolution module.
    """
    if dims == 1:
        return nn.ConvTranspose1d(*args, **kwargs)
    elif dims == 2:
        return nn.ConvTranspose2d(*args, **kwargs)
    elif dims == 3:
        return nn.ConvTranspose3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


def nonlinearity(x, swish=1.0):
    # swish
    if swish == 1.0:
        return F.silu(x)
    else:
        return x * F.sigmoid(x * float(swish))
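A usage sketch for the dimension-dispatching factories above (not part of the file):

```python
conv = conv_nd(2, 64, 128, 3, padding=1)        # equivalent to nn.Conv2d(64, 128, 3, padding=1)
pool = avg_pool_nd(2, kernel_size=2, stride=2)  # equivalent to nn.AvgPool2d(2, stride=2)

x = torch.randn(1, 64, 32, 32)
assert conv(x).shape == (1, 128, 32, 32)
assert pool(x).shape == (1, 64, 16, 16)
```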
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.use_conv_transpose = use_conv_transpose

        if use_conv_transpose:
            self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(x)

        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            x = self.conv(x)

        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        self.padding = padding
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.down = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.down = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.use_conv and self.padding == 0 and self.dims == 2:
            pad = (0, 1, 0, 1)
            x = F.pad(x, pad, mode="constant", value=0)
        return self.down(x)
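A shape sketch (not part of the file): with `use_conv=True`, `Upsample` doubles and `Downsample` halves the spatial resolution while keeping the channel count.

```python
up = Upsample(channels=64, use_conv=True)
down = Downsample(channels=64, use_conv=True)

x = torch.randn(1, 64, 16, 16)
y = up(x)
assert y.shape == (1, 64, 32, 32)        # nearest-neighbor interpolation + 3x3 conv
assert down(y).shape == (1, 64, 16, 16)  # strided 3x3 conv
```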
class UNetUpsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class GlideUpsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class LDMUpsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class GradTTSUpsample(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


# class ResnetBlock(nn.Module):
#     def __init__(
#         self,
#         *,
#         in_channels,
#         out_channels=None,
#         conv_shortcut=False,
#         dropout,
#         temb_channels=512,
#         use_scale_shift_norm=False,
#     ):
#         super().__init__()
#         self.in_channels = in_channels
#         out_channels = in_channels if out_channels is None else out_channels
#         self.out_channels = out_channels
#         self.use_conv_shortcut = conv_shortcut
#         self.use_scale_shift_norm = use_scale_shift_norm
#
#         self.norm1 = Normalize(in_channels)
#         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
#
#         temb_out_channels = 2 * out_channels if use_scale_shift_norm else out_channels
#         self.temb_proj = torch.nn.Linear(temb_channels, temb_out_channels)
#
#         self.norm2 = Normalize(out_channels)
#         self.dropout = torch.nn.Dropout(dropout)
#         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
#         if self.in_channels != self.out_channels:
#             if self.use_conv_shortcut:
#                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
#             else:
#                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
#
#     def forward(self, x, temb):
#         h = x
#         h = self.norm1(h)
#         h = nonlinearity(h)
#         h = self.conv1(h)
#
#         # TODO: check if this broadcasting works correctly for 1D and 3D
#         temb = self.temb_proj(nonlinearity(temb))[:, :, None, None]
#
#         if self.use_scale_shift_norm:
#             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
#             scale, shift = torch.chunk(temb, 2, dim=1)
#             h = self.norm2(h) * (1 + scale) + shift
#             h = out_rest(h)
#         else:
#             h = h + temb
#             h = self.norm2(h)
#             h = nonlinearity(h)
#             h = self.dropout(h)
#             h = self.conv2(h)
#
#         if self.in_channels != self.out_channels:
#             if self.use_conv_shortcut:
#                 x = self.conv_shortcut(x)
#             else:
#                 x = self.nin_shortcut(x)
#
#         return x + h
@@ -30,27 +30,7 @@ from tqdm import tqdm

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal embeddings. This matches the implementation in Denoising Diffusion Probabilistic Models (via
    Fairseq and tensor2tensor), but differs slightly from the description in Section 3.5 of "Attention Is All You
    Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
from .embeddings import get_timestep_embedding


def nonlinearity(x):
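With its default arguments (`downscale_freq_shift=1`, sine half first), the shared helper reproduces the local function deleted above, so the swap is behavior-preserving. A quick check of that claim — a sketch, assuming the import path introduced by this commit:

```python
import math

import torch

from diffusers.models.embeddings import get_timestep_embedding

t = torch.arange(10)
half_dim = 32  # embedding_dim // 2

# the removed unet.py implementation, inlined for comparison
freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -(math.log(10000) / (half_dim - 1)))
ref = torch.cat([torch.sin(t.float()[:, None] * freqs), torch.cos(t.float()[:, None] * freqs)], dim=1)

assert torch.allclose(get_timestep_embedding(t, 64), ref)
```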
@@ -7,6 +7,7 @@ import torch.nn.functional as F

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding


def convert_module_to_f16(l):

@@ -86,27 +87,6 @@ def normalization(channels, swish=0.0):
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        device=timesteps.device
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.

@@ -627,7 +607,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
        """

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        emb = self.time_embed(
            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        )

        h = x.type(self.dtype)
        for module in self.input_blocks:

@@ -714,7 +696,9 @@ class GlideTextToImageUNetModel(GlideUNetModel):

    def forward(self, x, timesteps, transformer_out=None):
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        emb = self.time_embed(
            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        )

        # project the last token
        transformer_proj = self.transformer_proj(transformer_out[:, -1])

@@ -806,7 +790,9 @@ class GlideSuperResUNetModel(GlideUNetModel):
        x = torch.cat([x, upsampled], dim=1)

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        emb = self.time_embed(
            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        )

        h = x
        for module in self.input_blocks:
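The GLIDE convention deleted above is cosine-first with frequency divisor `half` rather than `half - 1`, which is why every new call site passes `flip_sin_to_cos=True, downscale_freq_shift=0`. A sketch verifying the equivalence (same assumptions as the previous sketch):

```python
import math

import torch

from diffusers.models.embeddings import get_timestep_embedding

t = torch.arange(10)
half = 32  # dim // 2

# the removed timestep_embedding, inlined: cos first, divisor `half`
freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half)
args = t[:, None].float() * freqs[None]
glide_ref = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

assert torch.allclose(get_timestep_embedding(t, 64, flip_sin_to_cos=True, downscale_freq_shift=0), glide_ref)
```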
@@ -1,5 +1,3 @@
import math

import torch


@@ -11,6 +9,7 @@ except:

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding


class Mish(torch.nn.Module):

@@ -107,21 +106,6 @@ class Residual(torch.nn.Module):
        return output


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class UNetGradTTSModel(ModelMixin, ConfigMixin):
    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super(UNetGradTTSModel, self).__init__()

@@ -149,7 +133,6 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
            torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
        )

        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

        dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]

@@ -198,7 +181,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        if not isinstance(spk, type(None)):
            s = self.spk_mlp(spk)

        t = self.time_pos_emb(timesteps, scale=self.pe_scale)
        t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
        t = self.mlp(t)

        if self.n_spks < 2:
@@ -16,6 +16,7 @@ except:

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding


def exists(val):

@@ -316,36 +317,6 @@ def normalization(channels, swish=0.0):
    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        device=timesteps.device
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


## go
class AttentionPool2d(nn.Module):
    """

@@ -1026,7 +997,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        hs = []
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=x.device)
        t_emb = timestep_embedding(timesteps, self.model_channels)
        t_emb = get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:

@@ -1240,7 +1211,9 @@ class EncoderUNetModel(nn.Module):
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        emb = self.time_embed(
            get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        )

        results = []
        h = x.type(self.dtype)

@@ -13,7 +13,6 @@ except:
    print("Einops is not installed")
    pass


from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin

@@ -107,14 +106,21 @@ class ResidualTemporalBlock(nn.Module):
class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
    def __init__(
        self,
        horizon,
        training_horizon,
        transition_dim,
        cond_dim,
        predict_epsilon=False,
        clip_denoised=True,
        dim=32,
        dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        self.transition_dim = transition_dim
        self.cond_dim = cond_dim
        self.predict_epsilon = predict_epsilon
        self.clip_denoised = clip_denoised

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        # print(f'[ models/temporal ] Channel dimensions: {in_out}')

@@ -138,19 +144,19 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim, horizon=training_horizon),
                        Downsample1d(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                horizon = horizon // 2
                training_horizon = training_horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=training_horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

@@ -158,15 +164,15 @@ class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=training_horizon),
                        Upsample1d(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

            if not is_last:
                horizon = horizon * 2
                training_horizon = training_horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=5),

@@ -232,7 +238,6 @@ class TemporalValue(nn.Module):

        print(in_out)
        for dim_in, dim_out in in_out:

            self.blocks.append(
                nn.ModuleList(
                    [
src/diffusers/models/unet_sde_score_estimation.py (new file, 1061 lines)
File diff suppressed because it is too large
@@ -21,7 +21,6 @@ from typing import Optional, Union
from huggingface_hub import snapshot_download

from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .utils import DIFFUSERS_CACHE, logging


@@ -81,16 +80,13 @@ class DiffusionPipeline(ConfigMixin):
            # set models
            setattr(self, name, module)

        register_dict = {"_module": self.__module__.split(".")[-1]}
        self.register_to_config(**register_dict)

    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
        self.save_config(save_directory)

        model_index_dict = dict(self.config)
        model_index_dict.pop("_class_name")
        model_index_dict.pop("_diffusers_version")
        model_index_dict.pop("_module")
        model_index_dict.pop("_module", None)

        for pipeline_component_name in model_index_dict.keys():
            sub_model = getattr(self, pipeline_component_name)

@@ -139,11 +135,7 @@ class DiffusionPipeline(ConfigMixin):

        config_dict = cls.get_config_dict(cached_folder)

        # 2. Get class name and module candidates to load custom models
        module_candidate_name = config_dict["_module"]
        module_candidate = module_candidate_name + ".py"

        # 3. Load the pipeline class, if using custom module then load it from the hub
        # 2. Load the pipeline class, if using custom module then load it from the hub
        # if we load from explicit class, let's use it
        if cls != DiffusionPipeline:
            pipeline_class = cls

@@ -151,11 +143,6 @@ class DiffusionPipeline(ConfigMixin):
            diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
            pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

        # (TODO - we should allow to load custom pipelines
        # else we need to load the correct module from the Hub
        # module = module_candidate
        # pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)

        init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        init_kwargs = {}

@@ -163,7 +150,7 @@ class DiffusionPipeline(ConfigMixin):
        # import it here to avoid circular import
        from diffusers import pipelines

        # 4. Load each module in the pipeline
        # 3. Load each module in the pipeline
        for name, (library_name, class_name) in init_dict.items():
            is_pipeline_module = hasattr(pipelines, library_name)
            # if the model is in a pipeline module, then we load it from the pipeline

@@ -171,14 +158,7 @@ class DiffusionPipeline(ConfigMixin):
                pipeline_module = getattr(pipelines, library_name)
                class_obj = getattr(pipeline_module, class_name)
                importable_classes = ALL_IMPORTABLE_CLASSES
                class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
            elif library_name == module_candidate_name:
                # if the model is not in diffusers or transformers, we need to load it from the hub
                # assumes that it's a subclass of ModelMixin
                class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder)
                # since it's not from a library, we need to check class candidates for all importable classes
                importable_classes = ALL_IMPORTABLE_CLASSES
                class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
                class_candidates = {c: class_obj for c in importable_classes.keys()}
            else:
                # else we just import it from the library.
                library = importlib.import_module(library_name)
@@ -15,5 +15,5 @@ TODO(Patrick, Anton, Suraj)
- PNDM for unconditional image generation in [pipeline_pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
- Latent diffusion for text to image generation / conditional image generation in [pipeline_latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_latent_diffusion.py).
- Glide for text to image generation / conditional image generation in [pipeline_glide](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_glide.py).
- BDDM for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- BDDMPipeline for spectrogram-to-sound vocoding in [pipeline_bddm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_bddm.py).
- Grad-TTS for text to audio generation / conditional audio generation in [pipeline_grad_tts](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_grad_tts.py).
@@ -1,14 +1,19 @@
from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_pndm import PNDM
from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIMPipeline
from .pipeline_ddpm import DDPMPipeline
from .pipeline_pndm import PNDMPipeline
from .pipeline_score_sde_ve import ScoreSdeVePipeline
from .pipeline_score_sde_vp import ScoreSdeVpPipeline


# from .pipeline_score_sde import ScoreSdeVePipeline


if is_transformers_available():
    from .pipeline_glide import Glide
    from .pipeline_latent_diffusion import LatentDiffusion
    from .pipeline_glide import GlidePipeline
    from .pipeline_latent_diffusion import LatentDiffusionPipeline


if is_transformers_available() and is_unidecode_available() and is_inflect_available():
    from .pipeline_grad_tts import GradTTS
    from .pipeline_grad_tts import GradTTSPipeline
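After this rename, user-facing code must import the `*Pipeline` names; the bare class names no longer resolve. A sketch:

```python
# Old names no longer resolve:
# from diffusers import BDDM, DDIM, DDPM, PNDM   # ImportError after this commit
from diffusers import BDDMPipeline, DDIMPipeline, DDPMPipeline, PNDMPipeline  # new public names
```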
@@ -271,7 +271,7 @@ class DiffWave(ModelMixin, ConfigMixin):
        return self.final_conv(x)


class BDDM(DiffusionPipeline):
class BDDMPipeline(DiffusionPipeline):
    def __init__(self, diffwave, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
@@ -21,7 +21,7 @@ import tqdm
from ..pipeline_utils import DiffusionPipeline


class DDIM(DiffusionPipeline):
class DDIMPipeline(DiffusionPipeline):
    def __init__(self, unet, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")

@@ -21,7 +21,7 @@ import tqdm
from ..pipeline_utils import DiffusionPipeline


class DDPM(DiffusionPipeline):
class DDPMPipeline(DiffusionPipeline):
    def __init__(self, unet, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
@@ -695,7 +695,23 @@ class CLIPTextModel(CLIPPreTrainedModel):
#####################


class Glide(DiffusionPipeline):
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res + torch.zeros(broadcast_shape, device=timesteps.device)
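A usage sketch for `_extract_into_tensor` (values are illustrative): it gathers one scalar per batch element from a numpy schedule and broadcasts it over the sample shape.

```python
import numpy as np
import torch

betas = np.linspace(1e-4, 2e-2, 1000)  # illustrative noise schedule
t = torch.tensor([0, 499, 999])        # one timestep per batch element
coef = _extract_into_tensor(betas, t, broadcast_shape=(3, 3, 64, 64))
assert coef.shape == (3, 3, 64, 64)    # scalar per batch element, broadcast over CHW
```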
class GlidePipeline(DiffusionPipeline):
    def __init__(
        self,
        text_unet: GlideTextToImageUNetModel,
@@ -420,7 +420,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
        return mu, logw, x_mask


class GradTTS(DiffusionPipeline):
class GradTTSPipeline(DiffusionPipeline):
    def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")

@@ -430,7 +430,14 @@ class GradTTS(DiffusionPipeline):

    @torch.no_grad()
    def __call__(
        self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None
        self,
        text,
        num_inference_steps=50,
        temperature=1.3,
        length_scale=0.91,
        speaker_id=15,
        torch_device=None,
        generator=None,
    ):
        if torch_device is None:
            torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@@ -464,17 +471,19 @@ class GradTTS(DiffusionPipeline):
        mu_y = mu_y.transpose(1, 2)

        # Sample latent representation from terminal distribution N(mu_y, I)
        z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
        z = mu_y + torch.randn(mu_y.shape, generator=generator).to(mu_y.device)

        xt = z * y_mask
        h = 1.0 / num_inference_steps
        # (Patrick: TODO)
        for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
            t_new = num_inference_steps - t - 1
            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
            time = t.unsqueeze(-1).unsqueeze(-1)

            residual = self.unet(xt, t, mu_y, y_mask, speaker_id)

            xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
            scheduler_residual = residual - mu_y + xt
            xt = self.noise_scheduler.step(scheduler_residual, xt, t_new, num_inference_steps)
            xt = xt * y_mask

        return xt[:, :, :y_max_length]
@@ -1,17 +1,557 @@
# pytorch_diffusion + derived encoder decoder
import math
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint

import tqdm


try:
    from transformers.activations import ACT2FN
    from transformers.configuration_utils import PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput
    from transformers.modeling_utils import PreTrainedModel
    from transformers.utils import logging
except ImportError:
    raise ImportError("Please install the transformers library.")

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline


################################################################################
# Code for the text transformer model
################################################################################
""" PyTorch LDMBERT model."""


logger = logging.get_logger(__name__)

LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "ldm-bert",
    # See all LDMBert models at https://huggingface.co/models?filter=ldmbert
]


LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json",
}


""" LDMBERT model configuration"""


class LDMBertConfig(PretrainedConfig):
    model_type = "ldmbert"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=77,
        encoder_layers=32,
        encoder_ffn_dim=5120,
        encoder_attention_heads=8,
        head_dim=64,
        encoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1280,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        pad_token_id=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.head_dim = head_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        super().__init__(pad_token_id=pad_token_id, **kwargs)
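A small sketch of the config defaults and the `attribute_map` indirection (this relies on standard `PretrainedConfig` behavior, which resolves mapped attribute names on access):

```python
config = LDMBertConfig()
assert config.d_model == 1280
assert config.hidden_size == 1280            # attribute_map routes hidden_size -> d_model
assert config.num_attention_heads == 8       # ... and num_attention_heads -> encoder_attention_heads
assert config.max_position_embeddings == 77  # 77-token context
```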
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
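A usage sketch for `_expand_mask`: padding positions become the most negative representable value, so they vanish after softmax.

```python
import torch

mask = torch.tensor([[1, 1, 0]])                 # 1 = attend, 0 = padding
expanded = _expand_mask(mask, torch.float32)
assert expanded.shape == (1, 1, 3, 3)            # [bsz, 1, tgt_len, src_len]
assert expanded[0, 0, 0, 2] == torch.finfo(torch.float32).min
assert (expanded[0, 0, :, :2] == 0).all()        # attended positions add 0 to the logits
```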
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
|
||||
class LDMBertAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
dropout: float = 0.0,
|
||||
is_decoder: bool = False,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = head_dim
|
||||
self.inner_dim = head_dim * num_heads
|
||||
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(self.inner_dim, embed_dim)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
key_value_states: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
layer_head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
# for the decoder
|
||||
is_cross_attention = key_value_states is not None
|
||||
|
||||
bsz, tgt_len, _ = hidden_states.size()
|
||||
|
||||
# get query proj
|
||||
query_states = self.q_proj(hidden_states) * self.scaling
|
||||
# get key, value proj
|
||||
if is_cross_attention and past_key_value is not None:
|
||||
# reuse k,v, cross_attentions
|
||||
key_states = past_key_value[0]
|
||||
value_states = past_key_value[1]
|
||||
elif is_cross_attention:
|
||||
# cross_attentions
|
||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
||||
elif past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
else:
|
||||
# self_attention
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
if self.is_decoder:
|
||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||
# key/value_states (first "if" case)
|
||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||
past_key_value = (key_states, value_states)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
||||
key_states = key_states.view(*proj_shape)
|
||||
value_states = value_states.view(*proj_shape)
|
||||
|
||||
src_len = key_states.size(1)
|
||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
||||
|
||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if layer_head_mask is not None:
|
||||
if layer_head_mask.size() != (self.num_heads,):
|
||||
raise ValueError(
|
||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||
f" {layer_head_mask.size()}"
|
||||
)
|
||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if output_attentions:
|
||||
# this operation is a bit awkward, but it's required to
|
||||
# make sure that attn_weights keeps its gradient.
|
||||
# In order to do so, attn_weights have to be reshaped
|
||||
# twice and have to be reused in the following
|
||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
else:
|
||||
attn_weights_reshaped = None
|
||||
|
||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_probs, value_states)
|
||||
|
||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||
# partitioned aross GPUs when using tensor-parallelism.
|
||||
attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights_reshaped, past_key_value
|
||||
|
||||
|
||||
class LDMBertEncoderLayer(nn.Module):
    def __init__(self, config: LDMBertConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = LDMBertAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            head_dim=config.head_dim,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        layer_head_mask: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

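The `float16` guard at the end of the feed-forward block exists because half precision overflows at `torch.finfo(torch.float16).max` (65504); once an `inf` or `nan` shows up, activations are clamped just below that limit. A quick illustration of the failure mode it prevents:

```python
import torch

x = torch.tensor([70000.0]).to(torch.float16)        # overflows to inf in half precision
clamp_value = torch.finfo(torch.float16).max - 1000  # 64504.0
x = torch.clamp(x, min=-clamp_value, max=clamp_value)
print(torch.isinf(x).any())  # tensor(False) -- finite again after clamping
```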
# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
class LDMBertPreTrainedModel(PreTrainedModel):
    config_class = LDMBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (LDMBertEncoder,)):
            module.gradient_checkpointing = value

    @property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs

class LDMBertEncoder(LDMBertPreTrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`LDMBertEncoderLayer`].

    Args:
        config: LDMBertConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: LDMBertConfig):
        super().__init__(config)

        self.dropout = config.dropout

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings

        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
        self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim)
        self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        seq_len = input_shape[1]
        if position_ids is None:
            position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1))
        embed_pos = self.embed_positions(position_ids)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )

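A note on the gradient-checkpointing branch above: `torch.utils.checkpoint.checkpoint` re-runs the wrapped function with only the arguments passed to it positionally, so the `create_custom_forward` closure is used to bind the constant `output_attentions` flag while `hidden_states`, `attention_mask`, and the per-layer head mask are threaded through as positional inputs.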
class LDMBertModel(LDMBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = LDMBertEncoder(config)
        self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        return sequence_output

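Putting the pieces together, a minimal sketch of how this text encoder produces conditioning embeddings. The instantiation is hypothetical: it assumes the `LDMBertConfig` defined earlier in this file provides workable defaults, and the input ids below are dummy tokens rather than a real tokenized prompt:

```python
import torch

# hypothetical instantiation with default config values (assumption, not from the diff)
text_encoder = LDMBertModel(LDMBertConfig())

input_ids = torch.randint(0, 1000, (1, 77))  # (batch, seq_len), e.g. a 77-token padded prompt
with torch.no_grad():
    text_embedding = text_encoder(input_ids)

print(text_embedding.shape)  # (1, 77, d_model) -- returned directly, not wrapped in a ModelOutput
```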
def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:

@@ -860,7 +1400,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
        return dec, posterior

class LatentDiffusion(DiffusionPipeline):
class LatentDiffusionPipeline(DiffusionPipeline):
    def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
@@ -891,11 +1431,11 @@ class LatentDiffusion(DiffusionPipeline):
        uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
            torch_device
        )
        uncond_embeddings = self.bert(uncond_input.input_ids)[0]
        uncond_embeddings = self.bert(uncond_input.input_ids)

        # get text embedding
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
        text_embedding = self.bert(text_input.input_ids)[0]
        text_embedding = self.bert(text_input.input_ids)

        num_trained_timesteps = self.noise_scheduler.config.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

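Worth noting in this hunk: the `[0]` indexing on the `self.bert(...)` calls is dropped because `LDMBertModel.forward` above now returns `sequence_output` directly instead of a tuple-like model output.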
@@ -21,7 +21,7 @@ import tqdm
from ..pipeline_utils import DiffusionPipeline


class PNDM(DiffusionPipeline):
class PNDMPipeline(DiffusionPipeline):
    def __init__(self, unet, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")

44
src/diffusers/pipelines/pipeline_score_sde_ve.py
Normal file
@@ -0,0 +1,44 @@
#!/usr/bin/env python3
import torch

from diffusers import DiffusionPipeline


# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVePipeline(DiffusionPipeline):
    def __init__(self, model, scheduler):
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)

    def __call__(self, num_inference_steps=2000, generator=None):
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        img_size = self.model.config.image_size
        channels = self.model.config.num_channels
        shape = (1, channels, img_size, img_size)

        model = self.model.to(device)

        # TODO(Patrick) move to scheduler config
        n_steps = 1

        x = torch.randn(*shape) * self.scheduler.config.sigma_max
        x = x.to(device)

        self.scheduler.set_timesteps(num_inference_steps)
        self.scheduler.set_sigmas(num_inference_steps)

        for i, t in enumerate(self.scheduler.timesteps):
            sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)

            for _ in range(n_steps):
                with torch.no_grad():
                    result = self.model(x, sigma_t)
                x = self.scheduler.step_correct(result, x)

            with torch.no_grad():
                result = model(x, sigma_t)

            x, x_mean = self.scheduler.step_pred(result, x, t)

        return x_mean
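The loop above is the predictor-corrector (PC) sampler from the SDE paper: at each of the `num_inference_steps` noise levels it first runs `n_steps` Langevin-style corrector updates (`step_correct`), then one reverse-diffusion predictor update (`step_pred`), and finally returns `x_mean`, i.e. the last iterate without the freshly added noise. The equations both scheduler methods implement are spelled out after the `ScoreSdeVeScheduler` definition below.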
37
src/diffusers/pipelines/pipeline_score_sde_vp.py
Normal file
@@ -0,0 +1,37 @@
#!/usr/bin/env python3
import torch

from diffusers import DiffusionPipeline


# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVpPipeline(DiffusionPipeline):
    def __init__(self, model, scheduler):
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)

    def __call__(self, num_inference_steps=1000, generator=None):
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        img_size = self.model.config.image_size
        channels = self.model.config.num_channels
        shape = (1, channels, img_size, img_size)

        model = self.model.to(device)

        x = torch.randn(*shape).to(device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            t = t * torch.ones(shape[0], device=device)
            scaled_t = t * (num_inference_steps - 1)

            with torch.no_grad():
                result = model(x, scaled_t)

            x, x_mean = self.scheduler.step_pred(result, x, t)

        x_mean = (x_mean + 1.0) / 2.0

        return x_mean
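Two details in this pipeline: `scaled_t = t * (num_inference_steps - 1)` rescales the continuous timesteps in `(sampling_eps, 1]` to the discrete range the model expects as input, and the final `(x_mean + 1.0) / 2.0` maps samples from the model's `[-1, 1]` range into `[0, 1]` image space.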
@@ -20,4 +20,6 @@ from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_grad_tts import GradTTSScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin

@@ -92,9 +92,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

        # For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variane to pred_sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

        if variance_type is None:

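For reference, the `variance = ...` line above is the DDPM posterior variance from Eq. (7) of the paper:

```latex
\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t
```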
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin

@@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin
class GradTTSScheduler(SchedulerMixin, ConfigMixin):
    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_start=0.05,
        beta_end=20,
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
        )
        self.set_format(tensor_format=tensor_format)
        self.betas = None

    def sample_noise(self, timestep):
        noise = self.beta_start + (self.beta_end - self.beta_start) * timestep
        return noise
    def get_timesteps(self, num_inference_steps):
        return np.array([(t + 0.5) / num_inference_steps for t in range(num_inference_steps)])

    def step(self, xt, residual, mu, h, timestep):
        noise_t = self.sample_noise(timestep)
        dxt = 0.5 * (mu - xt - residual)
        dxt = dxt * noise_t * h
        xt = xt - dxt
        return xt
    def set_betas(self, num_inference_steps):
        timesteps = self.get_timesteps(num_inference_steps)
        self.betas = np.array([self.beta_start + (self.beta_end - self.beta_start) * t for t in timesteps])

    def __len__(self):
        return self.config.timesteps
    def step(self, residual, sample, t, num_inference_steps):
        # This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix)
        if self.betas is None:
            self.set_betas(num_inference_steps)

        beta_t = self.betas[t]
        beta_t_deriv = beta_t / num_inference_steps

        sample_deriv = residual * beta_t_deriv / 2

        sample = sample + sample_deriv
        return sample

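As a sanity check on the new beta schedule, a standalone recomputation of what `get_timesteps` and `set_betas` produce (assuming the defaults shown above, `beta_start=0.05` and `beta_end=20`):

```python
import numpy as np

beta_start, beta_end = 0.05, 20
num_inference_steps = 4

timesteps = np.array([(t + 0.5) / num_inference_steps for t in range(num_inference_steps)])
betas = beta_start + (beta_end - beta_start) * timesteps
print(timesteps)  # [0.125 0.375 0.625 0.875] -- midpoints, so beta is never evaluated at 0 or 1
print(betas)      # [ 2.54375  7.53125 12.51875 17.50625]
```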
84
src/diffusers/schedulers/scheduling_sde_ve.py
Normal file
@@ -0,0 +1,84 @@
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

# TODO(Patrick, Anton, Suraj) - make scheduler framework-independent and clean up a bit

import numpy as np
import torch

from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin


class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
        super().__init__()
        self.register_to_config(
            snr=snr,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            sampling_eps=sampling_eps,
        )

        self.sigmas = None
        self.discrete_sigmas = None
        self.timesteps = None

    def set_timesteps(self, num_inference_steps):
        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)

    def set_sigmas(self, num_inference_steps):
        if self.timesteps is None:
            self.set_timesteps(num_inference_steps)

        self.discrete_sigmas = torch.exp(
            torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
        )
        self.sigmas = torch.tensor(
            [self.config.sigma_min * (self.config.sigma_max / self.config.sigma_min) ** t for t in self.timesteps]
        )

    def step_pred(self, result, x, t):
        # TODO(Patrick) better comments + non-PyTorch
        t = t * torch.ones(x.shape[0], device=x.device)
        timestep = (t * (len(self.timesteps) - 1)).long()

        sigma = self.discrete_sigmas.to(t.device)[timestep]
        adjacent_sigma = torch.where(
            timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device)
        )
        f = torch.zeros_like(x)
        G = torch.sqrt(sigma**2 - adjacent_sigma**2)

        f = f - G[:, None, None, None] ** 2 * result

        z = torch.randn_like(x)
        x_mean = x - f
        x = x_mean + G[:, None, None, None] * z
        return x, x_mean

    def step_correct(self, result, x):
        # TODO(Patrick) better comments + non-PyTorch
        noise = torch.randn_like(x)
        grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        step_size = step_size * torch.ones(x.shape[0], device=x.device)
        x_mean = x + step_size[:, None, None, None] * result

        x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise

        return x
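In equation form, the two methods above implement the reverse-diffusion predictor and the annealed Langevin corrector from the paper: with $\sigma_i$ the discrete noise levels, $s_\theta$ the score-model output, $z \sim \mathcal{N}(0, I)$, and $\epsilon$ the SNR-derived step size,

```latex
x_{i-1} = x_i + (\sigma_i^2 - \sigma_{i-1}^2)\, s_\theta(x_i, \sigma_i) + \sqrt{\sigma_i^2 - \sigma_{i-1}^2}\; z
\qquad \text{(step\_pred)}
```

```latex
x \leftarrow x + \epsilon\, s_\theta(x, \sigma_i) + \sqrt{2\epsilon}\; z,
\quad
\epsilon = 2\left(\frac{\mathrm{snr} \cdot \lVert z \rVert}{\lVert s_\theta \rVert}\right)^2
\qquad \text{(step\_correct)}
```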
64
src/diffusers/schedulers/scheduling_sde_vp.py
Normal file
@@ -0,0 +1,64 @@
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

# TODO(Patrick, Anton, Suraj) - make scheduler framework-independent and clean up a bit

import numpy as np
import torch

from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin


class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
    def __init__(self, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
        super().__init__()
        self.register_to_config(
            beta_min=beta_min,
            beta_max=beta_max,
            sampling_eps=sampling_eps,
        )

        self.sigmas = None
        self.discrete_sigmas = None
        self.timesteps = None

    def set_timesteps(self, num_inference_steps):
        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)

    def step_pred(self, result, x, t):
        # TODO(Patrick) better comments + non-PyTorch
        # postprocess model result
        log_mean_coeff = (
            -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
        )
        std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
        result = -result / std[:, None, None, None]

        # compute
        dt = -1.0 / len(self.timesteps)

        beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
        drift = -0.5 * beta_t[:, None, None, None] * x
        diffusion = torch.sqrt(beta_t)
        drift = drift - diffusion[:, None, None, None] ** 2 * result
        x_mean = x + drift * dt

        # add noise
        z = torch.randn_like(x)
        x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z

        return x, x_mean
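The `step_pred` above is one Euler-Maruyama step of the reverse-time VP SDE: with $\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})$, the model output rescaled to a score estimate by dividing by the marginal std, and step size $dt = -1/N$,

```latex
x_{\text{mean}} = x + \Bigl[-\tfrac{1}{2}\beta(t)\,x - \beta(t)\, s_\theta(x, t)\Bigr]\, dt,
\qquad
x = x_{\text{mean}} + \sqrt{\beta(t)}\,\sqrt{-dt}\; z
```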
115
tests/test_layers_utils.py
Executable file
@@ -0,0 +1,115 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import inspect
import tempfile
import unittest

import numpy as np
import torch

from diffusers.models.embeddings import get_timestep_embedding
from diffusers.testing_utils import floats_tensor, slow, torch_device


torch.backends.cuda.matmul.allow_tf32 = False


class EmbeddingsTests(unittest.TestCase):
    def test_timestep_embeddings(self):
        embedding_dim = 256
        timesteps = torch.arange(16)

        t1 = get_timestep_embedding(timesteps, embedding_dim)

        # first vector should always be composed only of 0's and 1's
        assert (t1[0, : embedding_dim // 2] - 0).abs().sum() < 1e-5
        assert (t1[0, embedding_dim // 2 :] - 1).abs().sum() < 1e-5

        # last element of each vector should be one
        assert (t1[:, -1] - 1).abs().sum() < 1e-5

        # For large embeddings (e.g. 128) the frequency of every vector is higher
        # than the previous one which means that the gradients of later vectors are
        # ALWAYS higher than the previous ones
        grad_mean = np.abs(np.gradient(t1, axis=-1)).mean(axis=1)

        prev_grad = 0.0
        for grad in grad_mean:
            assert grad > prev_grad
            prev_grad = grad

    def test_timestep_defaults(self):
        embedding_dim = 16
        timesteps = torch.arange(10)

        t1 = get_timestep_embedding(timesteps, embedding_dim)
        t2 = get_timestep_embedding(
            timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000
        )

        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)

    def test_timestep_flip_sin_cos(self):
        embedding_dim = 16
        timesteps = torch.arange(10)

        t1 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=True)
        t1 = torch.cat([t1[:, embedding_dim // 2 :], t1[:, : embedding_dim // 2]], dim=-1)

        t2 = get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False)

        assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)

    def test_timestep_downscale_freq_shift(self):
        embedding_dim = 16
        timesteps = torch.arange(10)

        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0)
        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1)

        # get cosine half (vectors that are wrapped into cosine)
        cosine_half = (t1 - t2)[:, embedding_dim // 2 :]

        # cosine needs to be negative
        assert (np.abs((cosine_half <= 0).numpy()) - 1).sum() < 1e-5

    def test_sinusoid_embeddings_hardcoded(self):
        embedding_dim = 64
        timesteps = torch.arange(128)

        # standard unet, score_sde
        t1 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=1, flip_sin_to_cos=False)
        # glide, ldm
        t2 = get_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0, flip_sin_to_cos=True)
        # grad-tts
        t3 = get_timestep_embedding(timesteps, embedding_dim, scale=1000)

        assert torch.allclose(
            t1[23:26, 47:50].flatten().cpu(),
            torch.tensor([0.9646, 0.9804, 0.9892, 0.9615, 0.9787, 0.9882, 0.9582, 0.9769, 0.9872]),
            1e-3,
        )
        assert torch.allclose(
            t2[23:26, 47:50].flatten().cpu(),
            torch.tensor([0.3019, 0.2280, 0.1716, 0.3146, 0.2377, 0.1790, 0.3272, 0.2474, 0.1864]),
            1e-3,
        )
        assert torch.allclose(
            t3[23:26, 47:50].flatten().cpu(),
            torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
            1e-3,
        )
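The tests above pin down the embedding's conventions (sin half first, cos half second, defaults `flip_sin_to_cos=False`, `downscale_freq_shift=1`, `max_period=10_000` per `test_timestep_defaults`). For intuition, a minimal reimplementation consistent with those tests; this is a sketch, the real helper is `diffusers.models.embeddings.get_timestep_embedding` and may differ in details:

```python
import math
import torch

def timestep_embedding_sketch(
    timesteps, dim, max_period=10_000, downscale_freq_shift=1.0, flip_sin_to_cos=False, scale=1.0
):
    half = dim // 2
    # log-spaced frequencies decaying from 1 down to ~1/max_period
    exponent = -math.log(max_period) * torch.arange(half, dtype=torch.float32) / (half - downscale_freq_shift)
    args = scale * timesteps[:, None].float() * torch.exp(exponent)[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # sin half first, then cos half
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half:], emb[:, :half]], dim=-1)
    return emb
```

At `t = 0` this yields zeros in the sin half and ones in the cos half, matching the first assertion in `test_timestep_embeddings`.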
@@ -22,18 +22,24 @@ import numpy as np
import torch

from diffusers import (
    BDDM,
    DDIM,
    DDPM,
    PNDM,
    BDDMPipeline,
    DDIMPipeline,
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
    Glide,
    GlidePipeline,
    GlideSuperResUNetModel,
    GlideTextToImageUNetModel,
    GradTTS,
    LatentDiffusion,
    GradTTSPipeline,
    GradTTSScheduler,
    LatentDiffusionPipeline,
    NCSNpp,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
    ScoreSdeVeScheduler,
    ScoreSdeVpPipeline,
    ScoreSdeVpScheduler,
    UNetGradTTSModel,
    UNetLDMModel,
    UNetModel,
@@ -107,7 +113,7 @@ class ModelTesterMixin:
        new_image = new_model(**inputs_dict)

        max_diff = (image - new_image).abs().sum().item()
        self.assertLessEqual(max_diff, 1e-5, "Models give different forward passes")
        self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

    def test_determinism(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -425,11 +431,12 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
        emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device)
        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)

        model.to(torch_device)
        with torch.no_grad():
            output = model(noise, time_step, emb)

        output, _ = torch.split(output, 3, dim=1)
        output_slice = output[0, -1, -3:, -3:].flatten()
        output_slice = output[0, -1, -3:, -3:].cpu().flatten()
        # fmt: off
        expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845])
        # fmt: on
@@ -583,11 +590,11 @@ class PipelineTesterMixin(unittest.TestCase):
        model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
        scheduler = DDPMScheduler(timesteps=10)

        ddpm = DDPM(model, scheduler)
        ddpm = DDPMPipeline(model, scheduler)

        with tempfile.TemporaryDirectory() as tmpdirname:
            ddpm.save_pretrained(tmpdirname)
            new_ddpm = DDPM.from_pretrained(tmpdirname)
            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)

        generator = torch.manual_seed(0)

@@ -601,7 +608,7 @@ class PipelineTesterMixin(unittest.TestCase):
    def test_from_pretrained_hub(self):
        model_path = "fusing/ddpm-cifar10"

        ddpm = DDPM.from_pretrained(model_path)
        ddpm = DDPMPipeline.from_pretrained(model_path)
        ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)

        ddpm.noise_scheduler.num_timesteps = 10
@@ -624,7 +631,7 @@ class PipelineTesterMixin(unittest.TestCase):
        noise_scheduler = DDPMScheduler.from_config(model_id)
        noise_scheduler = noise_scheduler.set_format("pt")

        ddpm = DDPM(unet=unet, noise_scheduler=noise_scheduler)
        ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler)
        image = ddpm(generator=generator)

        image_slice = image[0, -1, -3:, -3:].cpu()
@@ -641,7 +648,7 @@ class PipelineTesterMixin(unittest.TestCase):
        unet = UNetModel.from_pretrained(model_id)
        noise_scheduler = DDIMScheduler(tensor_format="pt")

        ddim = DDIM(unet=unet, noise_scheduler=noise_scheduler)
        ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
        image = ddim(generator=generator, eta=0.0)

        image_slice = image[0, -1, -3:, -3:].cpu()
@@ -660,7 +667,7 @@ class PipelineTesterMixin(unittest.TestCase):
        unet = UNetModel.from_pretrained(model_id)
        noise_scheduler = PNDMScheduler(tensor_format="pt")

        pndm = PNDM(unet=unet, noise_scheduler=noise_scheduler)
        pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler)
        image = pndm(generator=generator)

        image_slice = image[0, -1, -3:, -3:].cpu()
@@ -672,9 +679,10 @@ class PipelineTesterMixin(unittest.TestCase):
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    @unittest.skip("Skipping for now as it takes too long")
    def test_ldm_text2img(self):
        model_id = "fusing/latent-diffusion-text2im-large"
        ldm = LatentDiffusion.from_pretrained(model_id)
        ldm = LatentDiffusionPipeline.from_pretrained(model_id)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
@@ -686,10 +694,25 @@ class PipelineTesterMixin(unittest.TestCase):
        expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_ldm_text2img_fast(self):
        model_id = "fusing/latent-diffusion-text2im-large"
        ldm = LatentDiffusionPipeline.from_pretrained(model_id)

        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
        image = ldm([prompt], generator=generator, num_inference_steps=1)

        image_slice = image[0, -1, -3:, -3:].cpu()

        assert image.shape == (1, 3, 256, 256)
        expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_glide_text2img(self):
        model_id = "fusing/glide-base"
        glide = Glide.from_pretrained(model_id)
        glide = GlidePipeline.from_pretrained(model_id)

        prompt = "a pencil sketch of a corgi"
        generator = torch.manual_seed(0)
@@ -704,22 +727,61 @@ class PipelineTesterMixin(unittest.TestCase):
    @slow
    def test_grad_tts(self):
        model_id = "fusing/grad-tts-libri-tts"
        grad_tts = GradTTS.from_pretrained(model_id)
        grad_tts = GradTTSPipeline.from_pretrained(model_id)
        noise_scheduler = GradTTSScheduler()
        grad_tts.noise_scheduler = noise_scheduler

        text = "Hello world, I missed you so much."
        generator = torch.manual_seed(0)

        # generate mel spectrograms using text
        mel_spec = grad_tts(text)
        mel_spec = grad_tts(text, generator=generator)

        assert mel_spec.shape == (1, 256, 256, 3)
        expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
        assert (mel_spec.flatten() - expected_slice).abs().max() < 1e-2
        assert mel_spec.shape == (1, 80, 143)
        expected_slice = torch.tensor(
            [-6.7584, -6.8347, -6.3293, -6.6437, -6.7233, -6.4684, -6.1187, -6.3172, -6.6890]
        )
        assert (mel_spec[0, :3, :3].cpu().flatten() - expected_slice).abs().max() < 1e-2

    @slow
    def test_score_sde_ve_pipeline(self):
        torch.manual_seed(0)

        model = NCSNpp.from_pretrained("fusing/ffhq_ncsnpp")
        scheduler = ScoreSdeVeScheduler.from_config("fusing/ffhq_ncsnpp")

        sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)

        image = sde_ve(num_inference_steps=2)

        expected_image_sum = 3382810112.0
        expected_image_mean = 1075.366455078125

        assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
        assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4

    @slow
    def test_score_sde_vp_pipeline(self):

        model = NCSNpp.from_pretrained("fusing/cifar10-ddpmpp-vp")
        scheduler = ScoreSdeVpScheduler.from_config("fusing/cifar10-ddpmpp-vp")

        sde_vp = ScoreSdeVpPipeline(model=model, scheduler=scheduler)

        torch.manual_seed(0)
        image = sde_vp(num_inference_steps=10)

        expected_image_sum = 4183.2012
        expected_image_mean = 1.3617

        assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
        assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4

    def test_module_from_pipeline(self):
        model = DiffWave(num_res_layers=4)
        noise_scheduler = DDPMScheduler(timesteps=12)

        bddm = BDDM(model, noise_scheduler)
        bddm = BDDMPipeline(model, noise_scheduler)

        # check if the library name for the diffwave module is set to pipeline module
        self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm")
@@ -727,6 +789,6 @@ class PipelineTesterMixin(unittest.TestCase):
        # check if we can save and load the pipeline
        with tempfile.TemporaryDirectory() as tmpdirname:
            bddm.save_pretrained(tmpdirname)
            _ = BDDM.from_pretrained(tmpdirname)
            _ = BDDMPipeline.from_pretrained(tmpdirname)
            # check if the same works using the DiffusionPipeline class
            _ = DiffusionPipeline.from_pretrained(tmpdirname)