From 31712deac3c679fd010e8e65f0b8b8aea6217742 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 15 Jun 2022 11:16:13 +0200
Subject: [PATCH] add unet grad tts

---
 src/diffusers/__init__.py             |   1 +
 src/diffusers/models/__init__.py      |   1 +
 src/diffusers/models/unet_grad_tts.py | 233 ++++++++++++++++++++++++++
 3 files changed, 235 insertions(+)
 create mode 100644 src/diffusers/models/unet_grad_tts.py

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index e374e3aed2..2f4d2ab6dc 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -8,6 +8,7 @@ from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
+from .models.unet_grad_tts import UNetGradTTSModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index dc98e2bb5e..9104bb9031 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -19,3 +19,4 @@
 from .unet import UNetModel
 from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .unet_ldm import UNetLDMModel
+from .unet_grad_tts import UNetGradTTSModel
diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py
new file mode 100644
index 0000000000..de2d6aa2f1
--- /dev/null
+++ b/src/diffusers/models/unet_grad_tts.py
@@ -0,0 +1,233 @@
+import math
+
+import torch
+
+try:
+    from einops import rearrange
+except ImportError:
+    print("Einops is not installed")
+    pass
+
+from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
+
+
+class Mish(torch.nn.Module):
+    def forward(self, x):
+        return x * torch.tanh(torch.nn.functional.softplus(x))
+
+
+class Upsample(torch.nn.Module):
+    def __init__(self, dim):
+        super(Upsample, self).__init__()
+        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Downsample(torch.nn.Module):
+    def __init__(self, dim):
+        super(Downsample, self).__init__()
+        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Rezero(torch.nn.Module):
+    def __init__(self, fn):
+        super(Rezero, self).__init__()
+        self.fn = fn
+        self.g = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        return self.fn(x) * self.g
+
+
+class Block(torch.nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super(Block, self).__init__()
+        self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
+                                         padding=1), torch.nn.GroupNorm(
+                                         groups, dim_out), Mish())
+
+    def forward(self, x, mask):
+        output = self.block(x * mask)
+        return output * mask
+
+
+class ResnetBlock(torch.nn.Module):
+    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
+        super(ResnetBlock, self).__init__()
+        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
+                                                               dim_out))
+
+        self.block1 = Block(dim, dim_out, groups=groups)
+        self.block2 = Block(dim_out, dim_out, groups=groups)
+        if dim != dim_out:
+            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
+        else:
+            self.res_conv = torch.nn.Identity()
+
+    def forward(self, x, mask, time_emb):
+        h = self.block1(x, mask)
+        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
+        h = self.block2(h, mask)
+        output = h + self.res_conv(x * mask)
+        return output
+
+
+class LinearAttention(torch.nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super(LinearAttention, self).__init__()
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(x)
+        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
+                            heads=self.heads, qkv=3)
+        k = k.softmax(dim=-1)
+        context = torch.einsum('bhdn,bhen->bhde', k, v)
+        out = torch.einsum('bhde,bhdn->bhen', context, q)
+        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
+                        heads=self.heads, h=h, w=w)
+        return self.to_out(out)
+
+
+class Residual(torch.nn.Module):
+    def __init__(self, fn):
+        super(Residual, self).__init__()
+        self.fn = fn
+
+    def forward(self, x, *args, **kwargs):
+        output = self.fn(x, *args, **kwargs) + x
+        return output
+
+
+class SinusoidalPosEmb(torch.nn.Module):
+    def __init__(self, dim):
+        super(SinusoidalPosEmb, self).__init__()
+        self.dim = dim
+
+    def forward(self, x, scale=1000):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class UNetGradTTSModel(ModelMixin, ConfigMixin):
+    def __init__(
+        self,
+        dim,
+        dim_mults=(1, 2, 4),
+        groups=8,
+        n_spks=None,
+        spk_emb_dim=64,
+        n_feats=80,
+        pe_scale=1000
+    ):
+        super(UNetGradTTSModel, self).__init__()
+
+        self.register(
+            dim=dim,
+            dim_mults=dim_mults,
+            groups=groups,
+            n_spks=n_spks,
+            spk_emb_dim=spk_emb_dim,
+            n_feats=n_feats,
+            pe_scale=pe_scale
+        )
+
+        self.dim = dim
+        self.dim_mults = dim_mults
+        self.groups = groups
+        self.n_spks = n_spks if n_spks is not None else 1
+        self.spk_emb_dim = spk_emb_dim
+        self.pe_scale = pe_scale
+
+        if self.n_spks > 1:
+            self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
+                                               torch.nn.Linear(spk_emb_dim * 4, n_feats))
+        self.time_pos_emb = SinusoidalPosEmb(dim)
+        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
+                                       torch.nn.Linear(dim * 4, dim))
+
+        dims = [2 + (1 if self.n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
+        in_out = list(zip(dims[:-1], dims[1:]))
+        self.downs = torch.nn.ModuleList([])
+        self.ups = torch.nn.ModuleList([])
+        num_resolutions = len(in_out)
+
+        for ind, (dim_in, dim_out) in enumerate(in_out):
+            is_last = ind >= (num_resolutions - 1)
+            self.downs.append(torch.nn.ModuleList([
+                ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
+                ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
+                Residual(Rezero(LinearAttention(dim_out))),
+                Downsample(dim_out) if not is_last else torch.nn.Identity()]))
+
+        mid_dim = dims[-1]
+        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
+        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
+        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
+
+        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+            self.ups.append(torch.nn.ModuleList([
+                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
+                ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
+                Residual(Rezero(LinearAttention(dim_in))),
+                Upsample(dim_in)]))
+        self.final_block = Block(dim, dim)
+        self.final_conv = torch.nn.Conv2d(dim, 1, 1)
+
+    def forward(self, x, mask, mu, t, spk=None):
+        if spk is not None:
+            s = self.spk_mlp(spk)
+
+        t = self.time_pos_emb(t, scale=self.pe_scale)
+        t = self.mlp(t)
+
+        if self.n_spks < 2:
+            x = torch.stack([mu, x], 1)
+        else:
+            s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
+            x = torch.stack([mu, x, s], 1)
+        mask = mask.unsqueeze(1)
+
+        hiddens = []
+        masks = [mask]
+        for resnet1, resnet2, attn, downsample in self.downs:
+            mask_down = masks[-1]
+            x = resnet1(x, mask_down, t)
+            x = resnet2(x, mask_down, t)
+            x = attn(x)
+            hiddens.append(x)
+            x = downsample(x * mask_down)
+            masks.append(mask_down[:, :, :, ::2])
+
+        masks = masks[:-1]
+        mask_mid = masks[-1]
+        x = self.mid_block1(x, mask_mid, t)
+        x = self.mid_attn(x)
+        x = self.mid_block2(x, mask_mid, t)
+
+        for resnet1, resnet2, attn, upsample in self.ups:
+            mask_up = masks.pop()
+            x = torch.cat((x, hiddens.pop()), dim=1)
+            x = resnet1(x, mask_up, t)
+            x = resnet2(x, mask_up, t)
+            x = attn(x)
+            x = upsample(x * mask_up)
+
+        x = self.final_block(x, mask)
+        output = self.final_conv(x * mask)
+
+        return (output * mask).squeeze(1)
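
Usage note (not part of the patch): a minimal single-speaker smoke-test sketch for the UNetGradTTSModel added above. The sizes below (dim=64, a batch of 2 mel-spectrograms with 80 bins and 172 frames) are illustrative assumptions; shapes follow the Grad-TTS convention of (batch, n_feats, frames) inputs, a (batch, 1, frames) mask, and a diffusion time t in [0, 1]. It assumes einops is installed (needed by the attention blocks) and that the surrounding ModelMixin/ConfigMixin utilities behave as for the other UNet models in this tree.

import torch

from diffusers import UNetGradTTSModel

# Illustrative values only; n_feats=80 matches the default mel dimension.
model = UNetGradTTSModel(dim=64, dim_mults=(1, 2, 4), n_feats=80)

x = torch.randn(2, 80, 172)   # noisy mel-spectrogram x_t
mu = torch.randn(2, 80, 172)  # text-encoder output (prior mean)
mask = torch.ones(2, 1, 172)  # frame-level mask
t = torch.rand(2)             # diffusion time per sample, in [0, 1]

with torch.no_grad():
    noise_estimate = model(x, mask, mu, t)  # single-speaker path (spk=None)

print(noise_estimate.shape)  # torch.Size([2, 80, 172])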