From 466214d2d66429be663dc8405a35cdef82d1e5f6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 29 Jun 2022 14:29:35 +0000 Subject: [PATCH] Remove bogus file --- src/diffusers/models/unet_grad_tts.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 432199ccc5..56614f8d9d 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -5,8 +5,7 @@ from ..modeling_utils import ModelMixin from .attention import LinearAttention from .embeddings import get_timestep_embedding from .resnet import Downsample -from .resnet import ResnetBlock as ResnetBlockNew -from .resnet import ResnetBlockGradTTS as ResnetBlock +from .resnet import ResnetBlock from .resnet import Upsample @@ -82,20 +81,13 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): self.ups = torch.nn.ModuleList([]) num_resolutions = len(in_out) -# num_groups = 8 -# self.pre_norm = False -# eps = 1e-5 -# non_linearity = "mish" - for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) self.downs.append( torch.nn.ModuleList( [ -# ResnetBlock(dim_in, dim_out, time_emb_dim=dim), -# ResnetBlock(dim_out, dim_out, time_emb_dim=dim), - ResnetBlockNew(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), - ResnetBlockNew(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), + ResnetBlock(in_channels=dim_in, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), + ResnetBlock(in_channels=dim_out, out_channels=dim_out, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), Residual(Rezero(LinearAttention(dim_out))), Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(), ] @@ -103,20 +95,16 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ) mid_dim = dims[-1] -# self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) -# self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) - self.mid_block1 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True) + self.mid_block1 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True) self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) - self.mid_block2 = ResnetBlockNew(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True) + self.mid_block2 = ResnetBlock(in_channels=mid_dim, out_channels=mid_dim, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): self.ups.append( torch.nn.ModuleList( [ -# ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), -# ResnetBlock(dim_in, dim_in, time_emb_dim=dim), - ResnetBlockNew(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), - ResnetBlockNew(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), + ResnetBlock(in_channels=dim_out * 2, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), + ResnetBlock(in_channels=dim_in, out_channels=dim_in, temb_channels=dim, groups=8, pre_norm=False, eps=1e-5, non_linearity="mish", overwrite_for_grad_tts=True), Residual(Rezero(LinearAttention(dim_in))), Upsample(dim_in, use_conv_transpose=True), ]