From c352faeae3fe14414324750f79985ef16d08e821 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 4 Jul 2022 15:06:00 +0200
Subject: [PATCH] Add MidBlock to Grad-TTS (#74)

Finish
---
 src/diffusers/models/unet_grad_tts.py | 22 +++++++++++++++-----
 src/diffusers/models/unet_new.py      | 29 +++++++++++++++------------
 2 files changed, 33 insertions(+), 18 deletions(-)

diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py
index 357d678495..a3934ba80c 100644
--- a/src/diffusers/models/unet_grad_tts.py
+++ b/src/diffusers/models/unet_grad_tts.py
@@ -5,6 +5,7 @@ from ..modeling_utils import ModelMixin
 from .attention import LinearAttention
 from .embeddings import get_timestep_embedding
 from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from .unet_new import UNetMidBlock2D
 
 
 class Mish(torch.nn.Module):
@@ -111,6 +112,17 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             )
 
         mid_dim = dims[-1]
+
+        self.mid = UNetMidBlock2D(
+            in_channels=mid_dim,
+            temb_channels=dim,
+            resnet_groups=8,
+            resnet_pre_norm=False,
+            resnet_eps=1e-5,
+            resnet_act_fn="mish",
+            attention_layer_type="linear",
+        )
+
         self.mid_block1 = ResnetBlock2D(
             in_channels=mid_dim,
             out_channels=mid_dim,
@@ -132,8 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             non_linearity="mish",
             overwrite_for_grad_tts=True,
         )
-
-        # self.mid = UNetMidBlock2D
+        self.mid.resnet_1 = self.mid_block1
+        self.mid.attn = self.mid_attn
+        self.mid.resnet_2 = self.mid_block2
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
@@ -198,9 +211,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
 
         masks = masks[:-1]
         mask_mid = masks[-1]
-        x = self.mid_block1(x, t, mask_mid)
-        x = self.mid_attn(x)
-        x = self.mid_block2(x, t, mask_mid)
+
+        x = self.mid(x, t, mask=mask_mid)
 
         for resnet1, resnet2, attn, upsample in self.ups:
             mask_up = masks.pop()
diff --git a/src/diffusers/models/unet_new.py b/src/diffusers/models/unet_new.py
index 066adb6a61..66d59bc60a 100644
--- a/src/diffusers/models/unet_new.py
+++ b/src/diffusers/models/unet_new.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from torch import nn
 
-from .attention import AttentionBlock, SpatialTransformer
+from .attention import AttentionBlock, LinearAttention, SpatialTransformer
 from .resnet import ResnetBlock2D
 
@@ -23,11 +23,12 @@ class UNetMidBlock2D(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
-        dropout: float,
+        dropout: float = 0.0,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
         attention_layer_type: str = "self",
         attn_num_heads=1,
         attn_num_head_channels=None,
@@ -49,6 +50,7 @@ class UNetMidBlock2D(nn.Module):
             time_embedding_norm=resnet_time_scale_shift,
             non_linearity=resnet_act_fn,
             output_scale_factor=output_scale_factor,
+            pre_norm=resnet_pre_norm,
         )
 
         if attention_layer_type == "self":
@@ -61,15 +63,14 @@ class UNetMidBlock2D(nn.Module):
                 rescale_output_factor=output_scale_factor,
             )
         elif attention_layer_type == "spatial":
-            self.attn = (
-                SpatialTransformer(
-                    in_channels,
-                    attn_num_heads,
-                    attn_num_head_channels,
-                    depth=attn_depth,
-                    context_dim=attn_encoder_channels,
-                ),
+            self.attn = SpatialTransformer(
+                attn_num_heads,
+                attn_num_head_channels,
+                depth=attn_depth,
+                context_dim=attn_encoder_channels,
             )
+        elif attention_layer_type == "linear":
+            self.attn = LinearAttention(in_channels)
 
         self.resnet_2 = ResnetBlock2D(
             in_channels=in_channels,
@@ -80,6 +81,7 @@ class UNetMidBlock2D(nn.Module):
             time_embedding_norm=resnet_time_scale_shift,
             non_linearity=resnet_act_fn,
             output_scale_factor=output_scale_factor,
+            pre_norm=resnet_pre_norm,
         )
 
         # TODO(Patrick) - delete all of the following code
@@ -110,19 +112,20 @@ class UNetMidBlock2D(nn.Module):
             eps=resnet_eps,
         )
 
-    def forward(self, hidden_states, temb=None, encoder_states=None):
+    def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
         if not self.is_overwritten and self.overwrite_unet:
             self.resnet_1 = self.block_1
             self.attn = self.attn_1
             self.resnet_2 = self.block_2
             self.is_overwritten = True
 
-        hidden_states = self.resnet_1(hidden_states, temb)
+        hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
 
         if encoder_states is None:
             hidden_states = self.attn(hidden_states)
         else:
             hidden_states = self.attn(hidden_states, encoder_states)
 
-        hidden_states = self.resnet_2(hidden_states, temb)
+        hidden_states = self.resnet_2(hidden_states, temb, mask=mask)
+
         return hidden_states
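
For reference, a minimal, self-contained sketch of the wiring this patch sets up: the mid computation is delegated to a single block that runs resnet -> attention -> resnet with an optional broadcastable mask (default 1.0), and the Grad-TTS model then overwrites that block's resnet_1 / attn / resnet_2 with its pre-built submodules. The Toy* classes below are illustrative stand-ins, not the diffusers ResnetBlock2D / UNetMidBlock2D APIs.

import torch
from torch import nn


class ToyResnetBlock(nn.Module):
    """Stand-in for ResnetBlock2D: accepts an optional broadcastable mask."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, hidden_states, temb=None, mask=1.0):
        # mask defaults to the scalar 1.0, so unmasked callers behave as before
        return hidden_states + self.conv(hidden_states * mask)


class ToyMidBlock(nn.Module):
    """Stand-in for UNetMidBlock2D: resnet -> attention -> resnet, mask-aware."""

    def __init__(self, channels):
        super().__init__()
        self.resnet_1 = ToyResnetBlock(channels)
        self.attn = nn.Identity()  # stand-in for LinearAttention
        self.resnet_2 = ToyResnetBlock(channels)

    def forward(self, hidden_states, temb=None, mask=1.0):
        hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
        hidden_states = self.attn(hidden_states)
        return self.resnet_2(hidden_states, temb, mask=mask)


mid = ToyMidBlock(8)
# The outer model keeps its pre-built sub-blocks and swaps them into the mid
# block, mirroring `self.mid.resnet_1 = self.mid_block1` in the patch.
mid.resnet_1 = ToyResnetBlock(8)

x = torch.randn(2, 8, 16, 16)
mask = torch.ones(2, 1, 16, 16)
print(mid(x, mask=mask).shape)  # torch.Size([2, 8, 16, 16])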