From ea8d58ea9186d3298c091de177fe9332199ac397 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 5 Jul 2022 15:05:41 +0200 Subject: [PATCH] [MidBlock] Fix mid block (#78) * upload files * finish --- src/diffusers/models/attention.py | 2 +- src/diffusers/models/unet.py | 18 +- src/diffusers/models/unet_glide.py | 8 +- src/diffusers/models/unet_grad_tts.py | 10 +- src/diffusers/models/unet_ldm.py | 6 +- src/diffusers/models/unet_new.py | 165 +++++++++--------- .../models/unet_sde_score_estimation.py | 6 +- 7 files changed, 114 insertions(+), 101 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 6daca7f0f0..395ac5579d 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -17,7 +17,7 @@ class LinearAttention(torch.nn.Module): self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) - def forward(self, x): + def forward(self, x, encoder_states=None): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = ( diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py index 563f4f2084..2eab5f3e8c 100644 --- a/src/diffusers/models/unet.py +++ b/src/diffusers/models/unet.py @@ -106,9 +106,20 @@ class UNetModel(ModelMixin, ConfigMixin): self.down.append(down) # middle - self.mid = UNetMidBlock2D( + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock2D( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True) + self.mid.block_2 = ResnetBlock2D( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid_new = UNetMidBlock2D( in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, overwrite_qkv=True, overwrite_unet=True ) + self.mid_new.resnets[0] = self.mid.block_1 + self.mid_new.attentions[0] = self.mid.attn_1 + self.mid_new.resnets[1] = self.mid.block_2 # upsampling self.up = nn.ModuleList() @@ -167,10 +178,7 @@ class UNetModel(ModelMixin, ConfigMixin): hs.append(self.down[i_level].downsample(hs[-1])) # middle - h = self.mid(hs[-1], temb) - # h = self.mid.block_1(h, temb) - # h = self.mid.attn_1(h) - # h = self.mid.block_2(h, temb) + h = self.mid_new(hs[-1], temb) # upsampling for i_level in reversed(range(self.num_resolutions)): diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 7dca03b63c..015a0b1855 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -266,9 +266,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin): overwrite_for_glide=True, ), ) - self.mid.resnet_1 = self.middle_block[0] - self.mid.attn = self.middle_block[1] - self.mid.resnet_2 = self.middle_block[2] + self.mid.resnets[0] = self.middle_block[0] + self.mid.attentions[0] = self.middle_block[1] + self.mid.resnets[1] = self.middle_block[2] self._feature_size += ch @@ -542,7 +542,7 @@ class GlideSuperResUNetModel(GlideUNetModel): for module in self.input_blocks: h = module(h, emb) hs.append(h) - h = self.middle_block(h, emb) + h = self.mid(h, emb) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb) diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index a3934ba80c..e9691a2f94 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -19,8 +19,8 @@ class Rezero(torch.nn.Module): self.fn = fn self.g = torch.nn.Parameter(torch.zeros(1)) - def forward(self, x): - return self.fn(x) * self.g + def forward(self, x, encoder_out=None): + return self.fn(x, encoder_out) * self.g class Block(torch.nn.Module): @@ -144,9 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): non_linearity="mish", overwrite_for_grad_tts=True, ) - self.mid.resnet_1 = self.mid_block1 - self.mid.attn = self.mid_attn - self.mid.resnet_2 = self.mid_block2 + self.mid.resnets[0] = self.mid_block1 + self.mid.attentions[0] = self.mid_attn + self.mid.resnets[1] = self.mid_block2 for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): self.ups.append( diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index a46a4c1848..9e74b1cf12 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -408,9 +408,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin): overwrite_for_ldm=True, ), ) - self.mid.resnet_1 = self.middle_block[0] - self.mid.attn = self.middle_block[1] - self.mid.resnet_2 = self.middle_block[2] + self.mid.resnets[0] = self.middle_block[0] + self.mid.attentions[0] = self.middle_block[1] + self.mid.resnets[1] = self.middle_block[2] self._feature_size += ch diff --git a/src/diffusers/models/unet_new.py b/src/diffusers/models/unet_new.py index 66d59bc60a..1fa3187ea5 100644 --- a/src/diffusers/models/unet_new.py +++ b/src/diffusers/models/unet_new.py @@ -24,6 +24,7 @@ class UNetMidBlock2D(nn.Module): in_channels: int, temb_channels: int, dropout: float = 0.0, + num_blocks: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", @@ -41,91 +42,95 @@ class UNetMidBlock2D(nn.Module): ): super().__init__() - self.resnet_1 = ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - - if attention_layer_type == "self": - self.attn = AttentionBlock( - in_channels, - num_heads=attn_num_heads, - num_head_channels=attn_num_head_channels, - encoder_channels=attn_encoder_channels, - overwrite_qkv=overwrite_qkv, - rescale_output_factor=output_scale_factor, - ) - elif attention_layer_type == "spatial": - self.attn = SpatialTransformer( - attn_num_heads, - attn_num_head_channels, - depth=attn_depth, - context_dim=attn_encoder_channels, - ) - elif attention_layer_type == "linear": - self.attn = LinearAttention(in_channels) - - self.resnet_2 = ResnetBlock2D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - - # TODO(Patrick) - delete all of the following code - self.is_overwritten = False - self.overwrite_unet = overwrite_unet - if self.overwrite_unet: - block_in = in_channels - self.temb_ch = temb_channels - self.block_1 = ResnetBlock2D( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + groups=resnet_groups, dropout=dropout, - eps=resnet_eps, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, ) - self.attn_1 = AttentionBlock( - block_in, - num_heads=attn_num_heads, - num_head_channels=attn_num_head_channels, - encoder_channels=attn_encoder_channels, - overwrite_qkv=True, - ) - self.block_2 = ResnetBlock2D( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - eps=resnet_eps, + ] + attentions = [] + + for _ in range(num_blocks): + if attention_layer_type == "self": + attentions.append( + AttentionBlock( + in_channels, + num_heads=attn_num_heads, + num_head_channels=attn_num_head_channels, + encoder_channels=attn_encoder_channels, + overwrite_qkv=overwrite_qkv, + rescale_output_factor=output_scale_factor, + ) + ) + elif attention_layer_type == "spatial": + attentions.append( + SpatialTransformer( + in_channels, + attn_num_heads, + attn_num_head_channels, + depth=attn_depth, + context_dim=attn_encoder_channels, + ) + ) + elif attention_layer_type == "linear": + attentions.append(LinearAttention(in_channels)) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0): - if not self.is_overwritten and self.overwrite_unet: - self.resnet_1 = self.block_1 - self.attn = self.attn_1 - self.resnet_2 = self.block_2 - self.is_overwritten = True + hidden_states = self.resnets[0](hidden_states, temb, mask=mask) - hidden_states = self.resnet_1(hidden_states, temb, mask=mask) - - if encoder_states is None: - hidden_states = self.attn(hidden_states) - else: - hidden_states = self.attn(hidden_states, encoder_states) - - hidden_states = self.resnet_2(hidden_states, temb, mask=mask) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_states) + hidden_states = resnet(hidden_states, temb, mask=mask) return hidden_states + + +# class UNetResAttnDownBlock(nn.Module): +# def __init__( +# self, +# in_channels: int, +# out_channels: int, +# temb_channels: int, +# dropout: float = 0.0, +# resnet_eps: float = 1e-6, +# resnet_time_scale_shift: str = "default", +# resnet_act_fn: str = "swish", +# resnet_groups: int = 32, +# resnet_pre_norm: bool = True, +# attention_layer_type: str = "self", +# attn_num_heads=1, +# attn_num_head_channels=None, +# attn_encoder_channels=None, +# attn_dim_head=None, +# attn_depth=None, +# output_scale_factor=1.0, +# overwrite_qkv=False, +# overwrite_unet=False, +# ): +# +# self.resents = diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index cdf6c6114f..f9de0e4a10 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin): overwrite_for_score_vde=True, ) ) - self.mid.resnet_1 = modules[len(modules) - 3] - self.mid.attn = modules[len(modules) - 2] - self.mid.resnet_2 = modules[len(modules) - 1] + self.mid.resnets[0] = modules[len(modules) - 3] + self.mid.attentions[0] = modules[len(modules) - 2] + self.mid.resnets[1] = modules[len(modules) - 1] pyramid_ch = 0 # Upsampling block