1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[MidBlock] Fix mid block (#78)

* upload files

* finish
This commit is contained in:
Patrick von Platen
2022-07-05 15:05:41 +02:00
committed by GitHub
parent c352faeae3
commit ea8d58ea91
7 changed files with 114 additions and 101 deletions

View File

@@ -17,7 +17,7 @@ class LinearAttention(torch.nn.Module):
self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
def forward(self, x, encoder_states=None):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = (

View File

@@ -106,9 +106,20 @@ class UNetModel(ModelMixin, ConfigMixin):
self.down.append(down)
# middle
self.mid = UNetMidBlock2D(
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
self.mid.block_2 = ResnetBlock2D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid_new = UNetMidBlock2D(
in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, overwrite_qkv=True, overwrite_unet=True
)
self.mid_new.resnets[0] = self.mid.block_1
self.mid_new.attentions[0] = self.mid.attn_1
self.mid_new.resnets[1] = self.mid.block_2
# upsampling
self.up = nn.ModuleList()
@@ -167,10 +178,7 @@ class UNetModel(ModelMixin, ConfigMixin):
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = self.mid(hs[-1], temb)
# h = self.mid.block_1(h, temb)
# h = self.mid.attn_1(h)
# h = self.mid.block_2(h, temb)
h = self.mid_new(hs[-1], temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):

View File

@@ -266,9 +266,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
overwrite_for_glide=True,
),
)
self.mid.resnet_1 = self.middle_block[0]
self.mid.attn = self.middle_block[1]
self.mid.resnet_2 = self.middle_block[2]
self.mid.resnets[0] = self.middle_block[0]
self.mid.attentions[0] = self.middle_block[1]
self.mid.resnets[1] = self.middle_block[2]
self._feature_size += ch
@@ -542,7 +542,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
for module in self.input_blocks:
h = module(h, emb)
hs.append(h)
h = self.middle_block(h, emb)
h = self.mid(h, emb)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb)

View File

@@ -19,8 +19,8 @@ class Rezero(torch.nn.Module):
self.fn = fn
self.g = torch.nn.Parameter(torch.zeros(1))
def forward(self, x):
return self.fn(x) * self.g
def forward(self, x, encoder_out=None):
return self.fn(x, encoder_out) * self.g
class Block(torch.nn.Module):
@@ -144,9 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
non_linearity="mish",
overwrite_for_grad_tts=True,
)
self.mid.resnet_1 = self.mid_block1
self.mid.attn = self.mid_attn
self.mid.resnet_2 = self.mid_block2
self.mid.resnets[0] = self.mid_block1
self.mid.attentions[0] = self.mid_attn
self.mid.resnets[1] = self.mid_block2
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
self.ups.append(

View File

@@ -408,9 +408,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
overwrite_for_ldm=True,
),
)
self.mid.resnet_1 = self.middle_block[0]
self.mid.attn = self.middle_block[1]
self.mid.resnet_2 = self.middle_block[2]
self.mid.resnets[0] = self.middle_block[0]
self.mid.attentions[0] = self.middle_block[1]
self.mid.resnets[1] = self.middle_block[2]
self._feature_size += ch

View File

@@ -24,6 +24,7 @@ class UNetMidBlock2D(nn.Module):
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_blocks: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
@@ -41,91 +42,95 @@ class UNetMidBlock2D(nn.Module):
):
super().__init__()
self.resnet_1 = ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
if attention_layer_type == "self":
self.attn = AttentionBlock(
in_channels,
num_heads=attn_num_heads,
num_head_channels=attn_num_head_channels,
encoder_channels=attn_encoder_channels,
overwrite_qkv=overwrite_qkv,
rescale_output_factor=output_scale_factor,
)
elif attention_layer_type == "spatial":
self.attn = SpatialTransformer(
attn_num_heads,
attn_num_head_channels,
depth=attn_depth,
context_dim=attn_encoder_channels,
)
elif attention_layer_type == "linear":
self.attn = LinearAttention(in_channels)
self.resnet_2 = ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
# TODO(Patrick) - delete all of the following code
self.is_overwritten = False
self.overwrite_unet = overwrite_unet
if self.overwrite_unet:
block_in = in_channels
self.temb_ch = temb_channels
self.block_1 = ResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
groups=resnet_groups,
dropout=dropout,
eps=resnet_eps,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
self.attn_1 = AttentionBlock(
block_in,
num_heads=attn_num_heads,
num_head_channels=attn_num_head_channels,
encoder_channels=attn_encoder_channels,
overwrite_qkv=True,
)
self.block_2 = ResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
eps=resnet_eps,
]
attentions = []
for _ in range(num_blocks):
if attention_layer_type == "self":
attentions.append(
AttentionBlock(
in_channels,
num_heads=attn_num_heads,
num_head_channels=attn_num_head_channels,
encoder_channels=attn_encoder_channels,
overwrite_qkv=overwrite_qkv,
rescale_output_factor=output_scale_factor,
)
)
elif attention_layer_type == "spatial":
attentions.append(
SpatialTransformer(
in_channels,
attn_num_heads,
attn_num_head_channels,
depth=attn_depth,
context_dim=attn_encoder_channels,
)
)
elif attention_layer_type == "linear":
attentions.append(LinearAttention(in_channels))
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
if not self.is_overwritten and self.overwrite_unet:
self.resnet_1 = self.block_1
self.attn = self.attn_1
self.resnet_2 = self.block_2
self.is_overwritten = True
hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
if encoder_states is None:
hidden_states = self.attn(hidden_states)
else:
hidden_states = self.attn(hidden_states, encoder_states)
hidden_states = self.resnet_2(hidden_states, temb, mask=mask)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_states)
hidden_states = resnet(hidden_states, temb, mask=mask)
return hidden_states
# class UNetResAttnDownBlock(nn.Module):
# def __init__(
# self,
# in_channels: int,
# out_channels: int,
# temb_channels: int,
# dropout: float = 0.0,
# resnet_eps: float = 1e-6,
# resnet_time_scale_shift: str = "default",
# resnet_act_fn: str = "swish",
# resnet_groups: int = 32,
# resnet_pre_norm: bool = True,
# attention_layer_type: str = "self",
# attn_num_heads=1,
# attn_num_head_channels=None,
# attn_encoder_channels=None,
# attn_dim_head=None,
# attn_depth=None,
# output_scale_factor=1.0,
# overwrite_qkv=False,
# overwrite_unet=False,
# ):
#
# self.resnets =

View File

@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
overwrite_for_score_vde=True,
)
)
self.mid.resnet_1 = modules[len(modules) - 3]
self.mid.attn = modules[len(modules) - 2]
self.mid.resnet_2 = modules[len(modules) - 1]
self.mid.resnets[0] = modules[len(modules) - 3]
self.mid.attentions[0] = modules[len(modules) - 2]
self.mid.resnets[1] = modules[len(modules) - 1]
pyramid_ch = 0
# Upsampling block