From 23d50522e7427ff74bc103af956fc114e4fc1969 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 15 Jun 2022 09:41:23 +0200 Subject: [PATCH 01/11] remove unused files --- .../pipelines/configuration_ldmbert.py | 146 --- src/diffusers/pipelines/modeling_vae.py | 859 ------------------ 2 files changed, 1005 deletions(-) delete mode 100644 src/diffusers/pipelines/configuration_ldmbert.py delete mode 100644 src/diffusers/pipelines/modeling_vae.py diff --git a/src/diffusers/pipelines/configuration_ldmbert.py b/src/diffusers/pipelines/configuration_ldmbert.py deleted file mode 100644 index 00d3ac907e..0000000000 --- a/src/diffusers/pipelines/configuration_ldmbert.py +++ /dev/null @@ -1,146 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" LDMBERT model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json", -} - - -class LDMBertConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LDMBertModel`]. It is used to instantiate a - LDMBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the LDMBERT - [facebook/ldmbert-large](https://huggingface.co/facebook/ldmbert-large) architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 50265): - Vocabulary size of the LDMBERT model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LDMBertModel`] or [`TFLDMBertModel`]. - d_model (`int`, *optional*, defaults to 1024): - Dimensionality of the layers and the pooler layer. - encoder_layers (`int`, *optional*, defaults to 12): - Number of encoder layers. - decoder_layers (`int`, *optional*, defaults to 12): - Number of decoder layers. - encoder_attention_heads (`int`, *optional*, defaults to 16): - Number of attention heads for each attention layer in the Transformer encoder. - decoder_attention_heads (`int`, *optional*, defaults to 16): - Number of attention heads for each attention layer in the Transformer decoder. - decoder_ffn_dim (`int`, *optional*, defaults to 4096): - Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. - encoder_ffn_dim (`int`, *optional*, defaults to 4096): - Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. - activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): - The non-linear activation function (function or string) in the encoder and pooler. 
If string, `"gelu"`, - `"relu"`, `"silu"` and `"gelu_new"` are supported. - dropout (`float`, *optional*, defaults to 0.1): - The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - activation_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for activations inside the fully connected layer. - classifier_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for classifier. - max_position_embeddings (`int`, *optional*, defaults to 1024): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - init_std (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - encoder_layerdrop: (`float`, *optional*, defaults to 0.0): - The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) - for more details. - decoder_layerdrop: (`float`, *optional*, defaults to 0.0): - The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) - for more details. - scale_embedding (`bool`, *optional*, defaults to `False`): - Scale embeddings by diving by sqrt(d_model). - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). - num_labels: (`int`, *optional*, defaults to 3): - The number of labels to use in [`LDMBertForSequenceClassification`]. - forced_eos_token_id (`int`, *optional*, defaults to 2): - The id of the token to force as the last generated token when `max_length` is reached. Usually set to - `eos_token_id`. 
- - Example: - - ```python - >>> from transformers import LDMBertModel, LDMBertConfig - - >>> # Initializing a LDMBERT facebook/ldmbert-large style configuration - >>> configuration = LDMBertConfig() - - >>> # Initializing a model from the facebook/ldmbert-large style configuration - >>> model = LDMBertModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - model_type = "ldmbert" - keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} - - def __init__( - self, - vocab_size=30522, - max_position_embeddings=77, - encoder_layers=32, - encoder_ffn_dim=5120, - encoder_attention_heads=8, - head_dim=64, - encoder_layerdrop=0.0, - activation_function="gelu", - d_model=1280, - dropout=0.1, - attention_dropout=0.0, - activation_dropout=0.0, - init_std=0.02, - classifier_dropout=0.0, - scale_embedding=False, - use_cache=True, - pad_token_id=0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.d_model = d_model - self.encoder_ffn_dim = encoder_ffn_dim - self.encoder_layers = encoder_layers - self.encoder_attention_heads = encoder_attention_heads - self.head_dim = head_dim - self.dropout = dropout - self.attention_dropout = attention_dropout - self.activation_dropout = activation_dropout - self.activation_function = activation_function - self.init_std = init_std - self.encoder_layerdrop = encoder_layerdrop - self.classifier_dropout = classifier_dropout - self.use_cache = use_cache - self.num_hidden_layers = encoder_layers - self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True - - super().__init__(pad_token_id=pad_token_id, **kwargs) diff --git a/src/diffusers/pipelines/modeling_vae.py b/src/diffusers/pipelines/modeling_vae.py deleted file mode 100644 index 7b299eee5e..0000000000 --- a/src/diffusers/pipelines/modeling_vae.py +++ /dev/null @@ -1,859 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import math - -import numpy as np -import torch -import torch.nn as nn - -import tqdm -from diffusers import DiffusionPipeline -from diffusers.configuration_utils import ConfigMixin -from diffusers.modeling_utils import ModelMixin - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". 
- """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - - def 
forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w) - q = q.permute(0, 2, 1) # b,hw,c - k = k.reshape(b, c, h * w) # b,c,hw - w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w) - w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x + h_ - - -class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - ): - super().__init__() - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, x, t=None): - # assert x.shape[2] == x.shape[3] == 
self.resolution - - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - **ignore_kwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) - - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - 
# middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - **ignorekwargs, - ): - super().__init__() - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(AttnBlock(block_in)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly - avoids costly matrix multiplications and allows for post-hoc remapping of indices. - """ - - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it. 
- def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - print( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - def forward(self, z, temp=None, rescale_logits=False, return_logits=False): - assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" - assert rescale_logits == False, "Only for interface compatible with Gumbel" - assert return_logits == False, "Only for interface compatible with Gumbel" - # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, "b c h w -> b h w c").contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")) - ) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, 
height, width, channel) - if self.remap is not None: - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class VQModel(ModelMixin, ConfigMixin): - def __init__( - self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - n_embed, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, - ): - super().__init__() - - # register all __init__ params with self.register - self.register( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - n_embed=n_embed, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - # pass init params to Encoder - self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, - ) - - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) - - # pass init params to Decoder - self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = 
np.log(2.0 * np.pi) - return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) - - def mode(self): - return self.mean - - -class AutoencoderKL(ModelMixin, ConfigMixin): - def __init__( - self, - ch, - out_ch, - num_res_blocks, - attn_resolutions, - in_channels, - resolution, - z_channels, - embed_dim, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - ch_mult=(1, 2, 4, 8), - dropout=0.0, - double_z=True, - resamp_with_conv=True, - give_pre_end=False, - ): - super().__init__() - - # register all __init__ params with self.register - self.register( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - embed_dim=embed_dim, - remap=remap, - sane_index_shape=sane_index_shape, - ch_mult=ch_mult, - dropout=dropout, - double_z=double_z, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - # pass init params to Encoder - self.encoder = Encoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - double_z=double_z, - give_pre_end=give_pre_end, - ) - - # pass init params to Decoder - self.decoder = Decoder( - ch=ch, - out_ch=out_ch, - num_res_blocks=num_res_blocks, - attn_resolutions=attn_resolutions, - in_channels=in_channels, - resolution=resolution, - z_channels=z_channels, - ch_mult=ch_mult, - dropout=dropout, - resamp_with_conv=resamp_with_conv, - give_pre_end=give_pre_end, - ) - - self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1) - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior From 76f0f1d453a074df113708eebe12c4bd00f19560 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 15 Jun 2022 09:44:18 +0200 Subject: [PATCH 02/11] update speech checkpoint name --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7c0a2fe71f..210efac201 100644 --- a/README.md +++ b/README.md @@ -197,7 +197,7 @@ from diffusers import BDDM, DiffusionPipeline torch_device = "cuda" # load the BDDM pipeline -bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder") +bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder-ljspeech") # load tacotron2 to get the mel spectograms tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16') From ca94e36c97a7ac0976522f7ddeaa063ab201eda0 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 15 Jun 2022 10:12:55 +0200 Subject: [PATCH 03/11] fix LatentDiffusion --- .../pipelines/pipeline_latent_diffusion.py | 52 +++++-------------- 1 file changed, 12 insertions(+), 40 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py index 20153aea40..fe1ed4ca30 100644 --- 
a/src/diffusers/pipelines/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py @@ -900,11 +900,12 @@ class LatentDiffusion(DiffusionPipeline): num_trained_timesteps = self.noise_scheduler.timesteps inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) - image = torch.randn( + image = self.noise_scheduler.sample_noise( (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size), + device=torch_device, generator=generator, ) - image = image.to(torch_device) + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding @@ -937,46 +938,17 @@ class LatentDiffusion(DiffusionPipeline): pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2) pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond) - # 2. get actual t and t-1 - train_step = inference_step_times[t] - prev_train_step = inference_step_times[t - 1] if t > 0 else -1 + # 2. predict previous mean of image x_t-1 + pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta) - # 3. compute alphas, betas - alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step) - alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step) - beta_prod_t = 1 - alpha_prod_t - beta_prod_t_prev = 1 - alpha_prod_t_prev + # 3. optionally sample variance + variance = 0 + if eta > 0: + noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) + variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise - # 4. Compute predicted previous image from predicted noise - # First: compute predicted original image from predicted noise also called - # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt() - - # Second: Compute variance: "sigma_t(η)" -> see formula (16) - # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) - std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt() - std_dev_t = eta * std_dev_t - - # Third: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t - - # Forth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction - - # 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image - # Note: eta = 1.0 essentially corresponds to DDPM - if eta > 0.0: - noise = torch.randn( - (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), - generator=generator, - ) - noise = noise.to(torch_device) - prev_image = pred_prev_image + std_dev_t * noise - else: - prev_image = pred_prev_image - - # 6. Set current image to prev_image: x_t -> x_t-1 - image = prev_image + # 4. 
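set current image to prev_image: x_t -> x_t-1
+        image = pred_prev_image + variance
 
         # scale and decode image with vae
         image = 1 / 0.18215 * image

For reference, the DDIM update that PATCH 03 folds into `noise_scheduler.step` can be written out in isolation (formulas (12) and (16) of https://arxiv.org/pdf/2010.02502.pdf). The sketch below mirrors the inline code the patch removes; the function name and the way the alpha products are passed in are illustrative assumptions, not the scheduler's actual API.

```python
import torch


def ddim_step(pred_noise, image, alpha_prod_t, alpha_prod_t_prev, eta=0.0):
    """One deterministic DDIM step; expects 0-d tensors for the alpha products."""
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # "predicted x_0" of formula (12)
    pred_original = (image - beta_prod_t.sqrt() * pred_noise) / alpha_prod_t.sqrt()

    # sigma_t(eta) of formula (16)
    std_dev_t = eta * (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()

    # "direction pointing to x_t" of formula (12)
    direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise

    return alpha_prod_t_prev.sqrt() * pred_original + direction
```

With `eta == 0` this is deterministic DDIM; for `eta > 0` the caller adds `std_dev_t * noise` on top, and `eta = 1.0` essentially corresponds to DDPM, which is why the pipeline only samples the extra variance term when `eta > 0`.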
From d4c2bcf8a385b54809b3d14026e1fc26622b9132 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 15 Jun 2022 10:15:54 +0200
Subject: [PATCH 04/11] fix noise in ldm

---
 src/diffusers/pipelines/pipeline_latent_diffusion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py
index fe1ed4ca30..88751312cf 100644
--- a/src/diffusers/pipelines/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py
@@ -944,7 +944,7 @@ class LatentDiffusion(DiffusionPipeline):
         # 3. optionally sample variance
         variance = 0
         if eta > 0:
-            noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
+            noise = torch.randn(image.shape, generator=generator, device=image.device)
             variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
 
         # 4. set current image to prev_image: x_t -> x_t-1

From a3784522a8af8f285dda4bb3c90a458e12d8fa37 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 15 Jun 2022 10:16:48 +0200
Subject: [PATCH 05/11] fix initial image in ddim

---
 src/diffusers/pipelines/pipeline_latent_diffusion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py
index 88751312cf..71530ba4b7 100644
--- a/src/diffusers/pipelines/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py
@@ -900,7 +900,7 @@ class LatentDiffusion(DiffusionPipeline):
         num_trained_timesteps = self.noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
 
-        image = self.noise_scheduler.sample_noise(
+        image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             device=torch_device,
             generator=generator,
         )

From cdb3c4931bdff39eff4aa863e0d9b3c6ceace047 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 15 Jun 2022 10:17:36 +0200
Subject: [PATCH 06/11] fix device for ldm

---
 src/diffusers/pipelines/pipeline_latent_diffusion.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py
index 71530ba4b7..e77d80cb94 100644
--- a/src/diffusers/pipelines/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py
@@ -902,9 +902,8 @@ class LatentDiffusion(DiffusionPipeline):
 
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
-            device=torch_device,
             generator=generator,
-        )
+        ).to(torch_device)
 
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding
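The small fixes in PATCHES 04–08 all circle one detail: the latent noise should come from a seeded `torch.Generator` on the CPU and only then be moved to the accelerator, so that a fixed seed reproduces the same image on any device. A minimal sketch of the resulting pattern (the shape and device here are illustrative assumptions):

```python
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# seeded CPU generator, as in the README example further below
generator = torch.manual_seed(42)

# sample with the CPU generator first, then move to the target device;
# passing device= directly would tie the draw to a device-specific RNG
noise = torch.randn((1, 4, 32, 32), generator=generator).to(torch_device)
```

Passing `device=torch_device` together with a CPU generator (the intermediate PATCH 04/05 state) either errors out or makes the draw device-dependent, which is presumably why PATCH 06 settles on `.to(torch_device)`.

From 8fdecfab00ad672d4070c2c672ac0442f2f1ffca Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 15 Jun 2022 10:18:13 +0200
Subject: [PATCH 07/11] fix noise device

---
 src/diffusers/pipelines/pipeline_latent_diffusion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py
index e77d80cb94..bd9688cc85 100644
--- a/src/diffusers/pipelines/pipeline_latent_diffusion.py
+++ 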
b/src/diffusers/pipelines/pipeline_latent_diffusion.py
@@ -943,7 +943,7 @@ class LatentDiffusion(DiffusionPipeline):
         # 3. optionally sample variance
         variance = 0
         if eta > 0:
-            noise = torch.randn(image.shape, generator=generator)to(image.device)
+            noise = torch.randn(image.shape, generator=generator).to(image.device)
             variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
 
         # 4. set current image to prev_image: x_t -> x_t-1

From 01b238d0de58ec1cf21aa8f651a412480d677636 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 15 Jun 2022 10:18:32 +0200
Subject: [PATCH 08/11] fix typo

---
 src/diffusers/pipelines/pipeline_latent_diffusion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py
index bd9688cc85..10ee253f44 100644
--- a/src/diffusers/pipelines/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py
@@ -943,7 +943,7 @@ class LatentDiffusion(DiffusionPipeline):
         # 3. optionally sample variance
         variance = 0
         if eta > 0:
-            noise = torch.randn(image.shape, generator=generator)to(image.device)
+            noise = torch.randn(image.shape, generator=generator).to(image.device)
             variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
 
         # 4. set current image to prev_image: x_t -> x_t-1

From 14a2201f7748e939632b18c6990230b11edcc1ee Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 15 Jun 2022 10:42:37 +0200
Subject: [PATCH 09/11] update ldm example

---
 README.md | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 210efac201..92c74a45ed 100644
--- a/README.md
+++ b/README.md
@@ -166,13 +166,14 @@ image_pil.save("test.png")
 
 #### **Text to Image generation with Latent Diffusion**
 
+_Note: To use latent diffusion install transformers from [this branch](https://github.com/patil-suraj/transformers/tree/ldm-bert)._
+
 ```python
 from diffusers import DiffusionPipeline
 
 ldm = DiffusionPipeline.from_pretrained("fusing/latent-diffusion-text2im-large")
 
-generator = torch.Generator()
-generator = generator.manual_seed(6694729458485568)
+generator = torch.manual_seed(42)
 
 prompt = "A painting of a squirrel eating a burger"
 image = ldm([prompt], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50)
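The `guidance_scale=6.0` in the README example above is the weight of classifier-free guidance, the computation visible as a context line in the PATCH 03 hunk. Isolated as a sketch (the tensor names are illustrative; in the pipeline both predictions come from one batched UNet call that is then `chunk`ed in two):

```python
import torch


def apply_guidance(
    pred_noise_uncond: torch.Tensor, pred_noise_cond: torch.Tensor, guidance_scale: float
) -> torch.Tensor:
    # move the noise prediction away from the unconditional branch, toward the prompt
    return pred_noise_uncond + guidance_scale * (pred_noise_cond - pred_noise_uncond)
```

From 31712deac3c679fd010e8e65f0b8b8aea6217742 Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Wed, 15 Jun 2022 11:16:13 +0200
Subject: [PATCH 10/11] add unet grad tts

---
 src/diffusers/__init__.py             |   1 +
 src/diffusers/models/__init__.py      |   1 +
 src/diffusers/models/unet_grad_tts.py | 233 ++++++++++++++++++++++++++
 3 files changed, 235 insertions(+)
 create mode 100644 src/diffusers/models/unet_grad_tts.py

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index e374e3aed2..2f4d2ab6dc 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -8,6 +8,7 @@ from .modeling_utils import ModelMixin
 from .models.unet import UNetModel
 from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
+from .models.unet_grad_tts import UNetGradTTSModel
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler

diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 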
dc98e2bb5e..9104bb9031 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -19,3 +19,4 @@ from .unet import UNetModel from .unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDETextToImageUNetModel from .unet_ldm import UNetLDMModel +from .unet_grad_tts import UNetGradTTSModel \ No newline at end of file diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py new file mode 100644 index 0000000000..de2d6aa2f1 --- /dev/null +++ b/src/diffusers/models/unet_grad_tts.py @@ -0,0 +1,233 @@ +import math + +import torch + +try: + from einops import rearrange, repeat +except: + print("Einops is not installed") + pass + +from ..configuration_utils import ConfigMixin +from ..modeling_utils import ModelMixin + +class Mish(torch.nn.Module): + def forward(self, x): + return x * torch.tanh(torch.nn.functional.softplus(x)) + + +class Upsample(torch.nn.Module): + def __init__(self, dim): + super(Upsample, self).__init__() + self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Downsample(torch.nn.Module): + def __init__(self, dim): + super(Downsample, self).__init__() + self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Rezero(torch.nn.Module): + def __init__(self, fn): + super(Rezero, self).__init__() + self.fn = fn + self.g = torch.nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return self.fn(x) * self.g + + +class Block(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super(Block, self).__init__() + self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3, + padding=1), torch.nn.GroupNorm( + groups, dim_out), Mish()) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super(ResnetBlock, self).__init__() + self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, + dim_out)) + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + if dim != dim_out: + self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) + else: + self.res_conv = torch.nn.Identity() + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class LinearAttention(torch.nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super(LinearAttention, self).__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', + heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', + heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class Residual(torch.nn.Module): + def __init__(self, fn): + super(Residual, self).__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + output = self.fn(x, *args, **kwargs) + x + return output + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super(SinusoidalPosEmb, 
self).__init__() + self.dim = dim + + def forward(self, x, scale=1000): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class UNetGradTTSModel(ModelMixin, ConfigMixin): + def __init__( + self, + dim, + dim_mults=(1, 2, 4), + groups=8, + n_spks=None, + spk_emb_dim=64, + n_feats=80, + pe_scale=1000 + ): + super(UNetGradTTSModel, self).__init__() + + self.register( + dim=dim, + dim_mults=dim_mults, + groups=groups, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + n_feats=n_feats, + pe_scale=pe_scale + ) + + self.dim = dim + self.dim_mults = dim_mults + self.groups = groups + self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1 + self.spk_emb_dim = spk_emb_dim + self.pe_scale = pe_scale + + if n_spks > 1: + self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), + torch.nn.Linear(spk_emb_dim * 4, n_feats)) + self.time_pos_emb = SinusoidalPosEmb(dim) + self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), + torch.nn.Linear(dim * 4, dim)) + + dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + self.downs = torch.nn.ModuleList([]) + self.ups = torch.nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + self.downs.append(torch.nn.ModuleList([ + ResnetBlock(dim_in, dim_out, time_emb_dim=dim), + ResnetBlock(dim_out, dim_out, time_emb_dim=dim), + Residual(Rezero(LinearAttention(dim_out))), + Downsample(dim_out) if not is_last else torch.nn.Identity()])) + + mid_dim = dims[-1] + self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) + self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) + self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + self.ups.append(torch.nn.ModuleList([ + ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), + ResnetBlock(dim_in, dim_in, time_emb_dim=dim), + Residual(Rezero(LinearAttention(dim_in))), + Upsample(dim_in)])) + self.final_block = Block(dim, dim) + self.final_conv = torch.nn.Conv2d(dim, 1, 1) + + def forward(self, x, mask, mu, t, spk=None): + if not isinstance(spk, type(None)): + s = self.spk_mlp(spk) + + t = self.time_pos_emb(t, scale=self.pe_scale) + t = self.mlp(t) + + if self.n_spks < 2: + x = torch.stack([mu, x], 1) + else: + s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1]) + x = torch.stack([mu, x, s], 1) + mask = mask.unsqueeze(1) + + hiddens = [] + masks = [mask] + for resnet1, resnet2, attn, downsample in self.downs: + mask_down = masks[-1] + x = resnet1(x, mask_down, t) + x = resnet2(x, mask_down, t) + x = attn(x) + hiddens.append(x) + x = downsample(x * mask_down) + masks.append(mask_down[:, :, :, ::2]) + + masks = masks[:-1] + mask_mid = masks[-1] + x = self.mid_block1(x, mask_mid, t) + x = self.mid_attn(x) + x = self.mid_block2(x, mask_mid, t) + + for resnet1, resnet2, attn, upsample in self.ups: + mask_up = masks.pop() + x = torch.cat((x, hiddens.pop()), dim=1) + x = resnet1(x, mask_up, t) + x = resnet2(x, mask_up, t) + x = attn(x) + x = upsample(x * mask_up) + + x = self.final_block(x, mask) + output = self.final_conv(x * mask) + + return (output * mask).squeeze(1) \ No newline at end of file From 
304d4d9057f6949570ffaebda60072fda4cda249 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 15 Jun 2022 11:16:24 +0200 Subject: [PATCH 11/11] begin pipeline grad tts --- src/diffusers/pipelines/pipeline_grad_tts.py | 385 +++++++++++++++++++ 1 file changed, 385 insertions(+) create mode 100644 src/diffusers/pipelines/pipeline_grad_tts.py diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py new file mode 100644 index 0000000000..2d8f694638 --- /dev/null +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -0,0 +1,385 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import math + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin +from diffusers.modeling_utils import ModelMixin + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(int(max_length), dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def fix_len_compatibility(length, num_downsamplings_in_unet=2): + while True: + if length % (2**num_downsamplings_in_unet) == 0: + return length + length += 1 + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + device = duration.device + + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], + [1, 0], [0, 0]]))[:, :-1] + path = path * mask + return path + + +def duration_loss(logw, logw_, lengths): + loss = torch.sum((logw - logw_)**2) / torch.sum(lengths) + return loss + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super(LayerNorm, self).__init__() + self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean)**2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, + n_layers, p_dropout): + super(ConvReluNorm, self).__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, + kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, + kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + 
+ def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): + super(DurationPredictor, self).__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, + kernel_size, padding=kernel_size//2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, + kernel_size, padding=kernel_size//2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = torch.nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, window_size=None, + heads_share=True, p_dropout=0.0, proximal_bias=False, + proximal_init=False): + super(MultiHeadAttention, self).__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.window_size = window_size + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, + window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, + window_size * 2 + 1, self.k_channels) * rel_stddev) + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." 
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
+            rel_logits = self._relative_position_to_absolute_position(rel_logits)
+            scores_local = rel_logits / math.sqrt(self.k_channels)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
+                                                                    dtype=scores.dtype)
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+        p_attn = torch.nn.functional.softmax(scores, dim=-1)
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+            output = output + self._matmul_with_relative_values(relative_weights,
+                                                                value_relative_embeddings)
+        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x, y):
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = torch.nn.functional.pad(
+                relative_embeddings, convert_pad_shape([[0, 0],
+                [pad_length, pad_length], [0, 0]]))
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[:,
+            slice_start_position:slice_end_position]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        batch, heads, length, _ = x.size()
+        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        batch, heads, length, _ = x.size()
+        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
+                 p_dropout=0.0):
+        super(FFN, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
+                                      padding=kernel_size//2)
+        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
+                                      padding=kernel_size//2)
+        self.drop = torch.nn.Dropout(p_dropout)
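+
+    # forward() applies two masked 1-D convolutions with a ReLU and dropout in
+    # between, preserving the sequence length (a position-wise feed-forward block).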
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        return x * x_mask
+
+
+class Encoder(nn.Module):
+    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
+                 kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
+        super(Encoder, self).__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+
+        self.drop = torch.nn.Dropout(p_dropout)
+        self.attn_layers = torch.nn.ModuleList()
+        self.norm_layers_1 = torch.nn.ModuleList()
+        self.ffn_layers = torch.nn.ModuleList()
+        self.norm_layers_2 = torch.nn.ModuleList()
+        for _ in range(self.n_layers):
+            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
+                                    n_heads, window_size=window_size, p_dropout=p_dropout))
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
+                                   filter_channels, kernel_size, p_dropout=p_dropout))
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        for i in range(self.n_layers):
+            x = x * x_mask
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class TextEncoder(ModelMixin, ConfigMixin):
+    def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
+                 filter_channels_dp, n_heads, n_layers, kernel_size,
+                 p_dropout, window_size=None, spk_emb_dim=64, n_spks=1):
+        super(TextEncoder, self).__init__()
+
+        self.register(
+            n_vocab=n_vocab,
+            n_feats=n_feats,
+            n_channels=n_channels,
+            filter_channels=filter_channels,
+            filter_channels_dp=filter_channels_dp,
+            n_heads=n_heads,
+            n_layers=n_layers,
+            kernel_size=kernel_size,
+            p_dropout=p_dropout,
+            window_size=window_size,
+            spk_emb_dim=spk_emb_dim,
+            n_spks=n_spks
+        )
+
+        self.n_vocab = n_vocab
+        self.n_feats = n_feats
+        self.n_channels = n_channels
+        self.filter_channels = filter_channels
+        self.filter_channels_dp = filter_channels_dp
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.spk_emb_dim = spk_emb_dim
+        self.n_spks = n_spks
+
+        self.emb = torch.nn.Embedding(n_vocab, n_channels)
+        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
+
+        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
+                                   kernel_size=5, n_layers=3, p_dropout=0.5)
+
+        self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
+                               kernel_size, p_dropout, window_size=window_size)
+
+        self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
+        self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
+                                        kernel_size, p_dropout)
+
+    def forward(self, x, x_lengths, spk=None):
+        x = self.emb(x) * math.sqrt(self.n_channels)
+        x = torch.transpose(x, 1, -1)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        x = self.prenet(x, x_mask)
+        if self.n_spks > 1:
+            x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
+        x = self.encoder(x, x_mask)
+        mu = self.proj_m(x) * x_mask
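+        # The duration predictor below runs on stop-gradient features:
+        # torch.detach keeps its loss from backpropagating into the encoder.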
+
+        x_dp = torch.detach(x)
+        logw = self.proj_w(x_dp, x_mask)
+
+        return mu, logw, x_mask
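+
+
+# Usage sketch (illustrative only): the hyperparameter values below are
+# assumptions loosely following the Grad-TTS defaults, not defined in this file.
+#
+#   encoder = TextEncoder(n_vocab=149, n_feats=80, n_channels=192,
+#                         filter_channels=768, filter_channels_dp=256,
+#                         n_heads=2, n_layers=6, kernel_size=3,
+#                         p_dropout=0.1, window_size=4)
+#   x = torch.randint(0, 149, (1, 37))   # dummy phoneme ids, shape (batch, t_x)
+#   x_lengths = torch.tensor([37])
+#   mu, logw, x_mask = encoder(x, x_lengths)
+#   # mu: (1, 80, 37) frame-level means; logw: (1, 1, 37) log-durations;
+#   # x_mask: (1, 1, 37) binary padding mask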