Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Commit ea8d58ea91 (parent c352faeae3), committed by GitHub
@@ -17,7 +17,7 @@ class LinearAttention(torch.nn.Module):
         self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
         self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
 
-    def forward(self, x):
+    def forward(self, x, encoder_states=None):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
         q, k, v = (
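Note: the widened signature above keeps LinearAttention call-compatible with the other attention layers, which this commit invokes uniformly as attn(hidden_states, encoder_states) even when a given layer ignores the conditioning. A minimal sketch (a stand-in module, not the diffusers implementation) of why the defaulted argument is enough:

import torch

# Hypothetical stand-in, not the diffusers LinearAttention: it only illustrates
# that adding a defaulted `encoder_states` keeps old call sites working while
# letting a generic loop pass the extra argument to every attention layer.
class TinyAttention(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = torch.nn.Conv2d(dim, dim, 1)

    def forward(self, x, encoder_states=None):
        # encoder_states is accepted for interface uniformity and ignored here
        return self.proj(x)

attn = TinyAttention(8)
x = torch.randn(2, 8, 16, 16)
assert attn(x).shape == x.shape                        # old-style call still works
assert attn(x, encoder_states=None).shape == x.shape   # new uniform call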
@@ -106,9 +106,20 @@ class UNetModel(ModelMixin, ConfigMixin):
             self.down.append(down)
 
         # middle
-        self.mid = UNetMidBlock2D(
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock2D(
+            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+        )
+        self.mid.attn_1 = AttentionBlock(block_in, overwrite_qkv=True)
+        self.mid.block_2 = ResnetBlock2D(
+            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
+        )
+        self.mid_new = UNetMidBlock2D(
+            in_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, overwrite_qkv=True, overwrite_unet=True
+        )
+        self.mid_new.resnets[0] = self.mid.block_1
+        self.mid_new.attentions[0] = self.mid.attn_1
+        self.mid_new.resnets[1] = self.mid.block_2
 
         # upsampling
         self.up = nn.ModuleList()
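Note: the three index assignments above make mid_new reuse the very same submodules that the legacy self.mid attributes point to, so both containers share one set of parameters and a checkpoint loaded through either path populates the other. A minimal sketch of that sharing with hypothetical stand-in layers (plain Conv2d instead of ResnetBlock2D / UNetMidBlock2D):

import torch
import torch.nn as nn

# Hypothetical stand-ins, only to show the module-sharing trick used above:
# after reassignment, the old and new containers hold the *same* Parameters.
old_mid = nn.Module()
old_mid.block_1 = nn.Conv2d(4, 4, 3, padding=1)

new_mid = nn.Module()
new_mid.resnets = nn.ModuleList([nn.Conv2d(4, 4, 3, padding=1)])

# mirrors `self.mid_new.resnets[0] = self.mid.block_1`
new_mid.resnets[0] = old_mid.block_1

assert new_mid.resnets[0] is old_mid.block_1
# loading weights through the old module updates the new container too
with torch.no_grad():
    old_mid.block_1.weight.zero_()
assert new_mid.resnets[0].weight.abs().sum() == 0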
@@ -167,10 +178,7 @@ class UNetModel(ModelMixin, ConfigMixin):
                 hs.append(self.down[i_level].downsample(hs[-1]))
 
         # middle
-        h = self.mid(hs[-1], temb)
-        # h = self.mid.block_1(h, temb)
-        # h = self.mid.attn_1(h)
-        # h = self.mid.block_2(h, temb)
+        h = self.mid_new(hs[-1], temb)
 
         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
@@ -266,9 +266,9 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
                 overwrite_for_glide=True,
             ),
         )
-        self.mid.resnet_1 = self.middle_block[0]
-        self.mid.attn = self.middle_block[1]
-        self.mid.resnet_2 = self.middle_block[2]
+        self.mid.resnets[0] = self.middle_block[0]
+        self.mid.attentions[0] = self.middle_block[1]
+        self.mid.resnets[1] = self.middle_block[2]
 
         self._feature_size += ch
 
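Note: this hunk, and the matching ones in UNetGradTTSModel, UNetLDMModel and NCSNpp below, swap the legacy resnet_1 / attn / resnet_2 attributes for indexed entries of the mid block's resnets / attentions ModuleLists: resnets[0] is the first resnet, attentions[0] the attention between, resnets[1] the second resnet. A small sketch (hypothetical Linear layers, not the real blocks) of how that layout shows up in parameter names:

import torch.nn as nn

# Hypothetical mid block mirroring the new layout: ModuleLists instead of the
# named resnet_1 / attn / resnet_2 attributes.
mid = nn.Module()
mid.resnets = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])
mid.attentions = nn.ModuleList([nn.Linear(4, 4)])

# parameters are now addressed by list index rather than by attribute name
print(sorted(k for k in mid.state_dict() if k.endswith("weight")))
# ['attentions.0.weight', 'resnets.0.weight', 'resnets.1.weight']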
@@ -542,7 +542,7 @@ class GlideSuperResUNetModel(GlideUNetModel):
         for module in self.input_blocks:
             h = module(h, emb)
             hs.append(h)
-        h = self.middle_block(h, emb)
+        h = self.mid(h, emb)
         for module in self.output_blocks:
             h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
@@ -19,8 +19,8 @@ class Rezero(torch.nn.Module):
         self.fn = fn
         self.g = torch.nn.Parameter(torch.zeros(1))
 
-    def forward(self, x):
-        return self.fn(x) * self.g
+    def forward(self, x, encoder_out=None):
+        return self.fn(x, encoder_out) * self.g
 
 
 class Block(torch.nn.Module):
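Note: Rezero still scales the wrapped function's output by the zero-initialized gate g, but now forwards an optional encoder_out to it. A standalone usage sketch, reconstructed with an assumed __init__(self, fn) signature (only the lines shown in the diff are taken verbatim):

import torch

class Rezero(torch.nn.Module):
    # g starts at zero, so the wrapped block initially contributes nothing;
    # encoder_out is passed straight through to the wrapped callable.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x, encoder_out=None):
        return self.fn(x, encoder_out) * self.g

# wrap a toy function that accepts the extra conditioning argument
rz = Rezero(lambda x, enc=None: x + (0 if enc is None else enc))
x = torch.randn(3)
assert torch.allclose(rz(x), torch.zeros(3))  # zero gate at init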
@@ -144,9 +144,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
             non_linearity="mish",
             overwrite_for_grad_tts=True,
         )
-        self.mid.resnet_1 = self.mid_block1
-        self.mid.attn = self.mid_attn
-        self.mid.resnet_2 = self.mid_block2
+        self.mid.resnets[0] = self.mid_block1
+        self.mid.attentions[0] = self.mid_attn
+        self.mid.resnets[1] = self.mid_block2
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
             self.ups.append(
@@ -408,9 +408,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                 overwrite_for_ldm=True,
            ),
         )
-        self.mid.resnet_1 = self.middle_block[0]
-        self.mid.attn = self.middle_block[1]
-        self.mid.resnet_2 = self.middle_block[2]
+        self.mid.resnets[0] = self.middle_block[0]
+        self.mid.attentions[0] = self.middle_block[1]
+        self.mid.resnets[1] = self.middle_block[2]
 
         self._feature_size += ch
 
@@ -24,6 +24,7 @@ class UNetMidBlock2D(nn.Module):
         in_channels: int,
         temb_channels: int,
         dropout: float = 0.0,
+        num_blocks: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -41,91 +42,95 @@ class UNetMidBlock2D(nn.Module):
     ):
         super().__init__()
 
-        self.resnet_1 = ResnetBlock2D(
-            in_channels=in_channels,
-            out_channels=in_channels,
-            temb_channels=temb_channels,
-            groups=resnet_groups,
-            dropout=dropout,
-            time_embedding_norm=resnet_time_scale_shift,
-            non_linearity=resnet_act_fn,
-            output_scale_factor=output_scale_factor,
-            pre_norm=resnet_pre_norm,
-        )
-
-        if attention_layer_type == "self":
-            self.attn = AttentionBlock(
-                in_channels,
-                num_heads=attn_num_heads,
-                num_head_channels=attn_num_head_channels,
-                encoder_channels=attn_encoder_channels,
-                overwrite_qkv=overwrite_qkv,
-                rescale_output_factor=output_scale_factor,
-            )
-        elif attention_layer_type == "spatial":
-            self.attn = SpatialTransformer(
-                attn_num_heads,
-                attn_num_head_channels,
-                depth=attn_depth,
-                context_dim=attn_encoder_channels,
-            )
-        elif attention_layer_type == "linear":
-            self.attn = LinearAttention(in_channels)
-
-        self.resnet_2 = ResnetBlock2D(
-            in_channels=in_channels,
-            out_channels=in_channels,
-            temb_channels=temb_channels,
-            groups=resnet_groups,
-            dropout=dropout,
-            time_embedding_norm=resnet_time_scale_shift,
-            non_linearity=resnet_act_fn,
-            output_scale_factor=output_scale_factor,
-            pre_norm=resnet_pre_norm,
-        )
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                groups=resnet_groups,
+                dropout=dropout,
+                eps=resnet_eps,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+
+        for _ in range(num_blocks):
+            if attention_layer_type == "self":
+                attentions.append(
+                    AttentionBlock(
+                        in_channels,
+                        num_heads=attn_num_heads,
+                        num_head_channels=attn_num_head_channels,
+                        encoder_channels=attn_encoder_channels,
+                        overwrite_qkv=overwrite_qkv,
+                        rescale_output_factor=output_scale_factor,
+                    )
+                )
+            elif attention_layer_type == "spatial":
+                attentions.append(
+                    SpatialTransformer(
+                        in_channels,
+                        attn_num_heads,
+                        attn_num_head_channels,
+                        depth=attn_depth,
+                        context_dim=attn_encoder_channels,
+                    )
+                )
+            elif attention_layer_type == "linear":
+                attentions.append(LinearAttention(in_channels))
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
 
         # TODO(Patrick) - delete all of the following code
         self.is_overwritten = False
         self.overwrite_unet = overwrite_unet
         if self.overwrite_unet:
             block_in = in_channels
             self.temb_ch = temb_channels
             self.block_1 = ResnetBlock2D(
                 in_channels=block_in,
                 out_channels=block_in,
                 temb_channels=self.temb_ch,
             )
             self.attn_1 = AttentionBlock(
                 block_in,
                 num_heads=attn_num_heads,
                 num_head_channels=attn_num_head_channels,
                 encoder_channels=attn_encoder_channels,
                 overwrite_qkv=True,
             )
             self.block_2 = ResnetBlock2D(
                 in_channels=block_in,
                 out_channels=block_in,
                 temb_channels=self.temb_ch,
                 dropout=dropout,
                 eps=resnet_eps,
             )
 
     def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
         if not self.is_overwritten and self.overwrite_unet:
             self.resnet_1 = self.block_1
             self.attn = self.attn_1
             self.resnet_2 = self.block_2
             self.is_overwritten = True
 
-        hidden_states = self.resnet_1(hidden_states, temb, mask=mask)
-
-        if encoder_states is None:
-            hidden_states = self.attn(hidden_states)
-        else:
-            hidden_states = self.attn(hidden_states, encoder_states)
-
-        hidden_states = self.resnet_2(hidden_states, temb, mask=mask)
+        hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            hidden_states = attn(hidden_states, encoder_states)
+            hidden_states = resnet(hidden_states, temb, mask=mask)
 
         return hidden_states
 
 
 # class UNetResAttnDownBlock(nn.Module):
 #     def __init__(
 #         self,
 #         in_channels: int,
 #         out_channels: int,
 #         temb_channels: int,
 #         dropout: float = 0.0,
 #         resnet_eps: float = 1e-6,
 #         resnet_time_scale_shift: str = "default",
 #         resnet_act_fn: str = "swish",
 #         resnet_groups: int = 32,
 #         resnet_pre_norm: bool = True,
 #         attention_layer_type: str = "self",
 #         attn_num_heads=1,
 #         attn_num_head_channels=None,
 #         attn_encoder_channels=None,
 #         attn_dim_head=None,
 #         attn_depth=None,
 #         output_scale_factor=1.0,
 #         overwrite_qkv=False,
 #         overwrite_unet=False,
 #     ):
 #
 #         self.resents =
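Note: in the rewritten UNetMidBlock2D there are always num_blocks attention layers and num_blocks + 1 resnets, and forward runs the first resnet followed by alternating attention/resnet pairs. A minimal sketch of that control flow with stub layers (not the real ResnetBlock2D / AttentionBlock):

import torch
import torch.nn as nn

class StubResnet(nn.Module):
    def forward(self, hidden_states, temb=None, mask=1.0):
        return hidden_states

class StubAttention(nn.Module):
    def forward(self, hidden_states, encoder_states=None):
        return hidden_states

class MidBlockSketch(nn.Module):
    # mirrors the new structure: num_blocks attentions, num_blocks + 1 resnets
    def __init__(self, num_blocks=1):
        super().__init__()
        self.resnets = nn.ModuleList([StubResnet() for _ in range(num_blocks + 1)])
        self.attentions = nn.ModuleList([StubAttention() for _ in range(num_blocks)])

    def forward(self, hidden_states, temb=None, encoder_states=None, mask=1.0):
        # first resnet, then alternating attention / resnet, as in the new forward
        hidden_states = self.resnets[0](hidden_states, temb, mask=mask)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_states)
            hidden_states = resnet(hidden_states, temb, mask=mask)
        return hidden_states

block = MidBlockSketch(num_blocks=2)
out = block(torch.randn(1, 4, 8, 8))
assert out.shape == (1, 4, 8, 8)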
@@ -249,9 +249,9 @@ class NCSNpp(ModelMixin, ConfigMixin):
                 overwrite_for_score_vde=True,
             )
         )
-        self.mid.resnet_1 = modules[len(modules) - 3]
-        self.mid.attn = modules[len(modules) - 2]
-        self.mid.resnet_2 = modules[len(modules) - 1]
+        self.mid.resnets[0] = modules[len(modules) - 3]
+        self.mid.attentions[0] = modules[len(modules) - 2]
+        self.mid.resnets[1] = modules[len(modules) - 1]
 
         pyramid_ch = 0
         # Upsampling block
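Note: modules[len(modules) - 3] through modules[len(modules) - 1] pick out the last three modules appended for the mid section (resnet, attention, resnet), i.e. the same elements as modules[-3:]. A tiny illustration with a generic list (the labels are hypothetical):

# the last three appended entries are the mid-section resnet, attention, resnet
modules = ["down_0", "down_1", "mid_resnet_1", "mid_attn", "mid_resnet_2"]

resnet_a = modules[len(modules) - 3]   # same as modules[-3]
attn     = modules[len(modules) - 2]   # same as modules[-2]
resnet_b = modules[len(modules) - 1]   # same as modules[-1]

assert (resnet_a, attn, resnet_b) == tuple(modules[-3:])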