diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 0b5262907a..9a7eaa2ecd 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -46,8 +46,8 @@ def conv_transpose_nd(dims, *args, **kwargs): raise ValueError(f"unsupported dimensions: {dims}") -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) +def Normalize(in_channels, num_groups=32, eps=1e-6): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True) def nonlinearity(x, swish=1.0): @@ -166,8 +166,8 @@ class Downsample(nn.Module): # # class GlideUpsample(nn.Module): # """ -# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param -# use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If +# An upsampling layer with an optional convolution. # # :param channels: channels in the inputs and outputs. :param # +use_conv: a bool determining if a convolution is # applied. :param dims: determines if the signal is 1D, 2D, or 3D. If # 3D, then # upsampling occurs in the inner-two dimensions. #""" # # def __init__(self, channels, use_conv, dims=2, out_channels=None): @@ -192,8 +192,8 @@ class Downsample(nn.Module): # # class LDMUpsample(nn.Module): # """ -# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # -# use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If +# An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param # # +use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. # If # 3D, then # upsampling occurs in the inner-two dimensions. #""" # # def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): @@ -340,40 +340,118 @@ class ResBlock(TimestepBlock): return self.skip_connection(x) + h -# unet.py +# unet.py and unet_grad_tts.py class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + overwrite_for_grad_tts=False, + ): super().__init__() + self.pre_norm = pre_norm self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut - self.norm1 = Normalize(in_channels) + if self.pre_norm: + self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps) + else: + self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) + self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps) self.dropout = torch.nn.Dropout(dropout) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if non_linearity == "swish": + self.nonlinearity = nonlinearity + elif non_linearity == "mish": + self.nonlinearity = Mish() + if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - def forward(self, x, temb): + self.is_overwritten = False + self.overwrite_for_grad_tts = overwrite_for_grad_tts + if self.overwrite_for_grad_tts: + dim = in_channels + dim_out = out_channels + time_emb_dim = temb_channels + self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out)) + self.pre_norm = pre_norm + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + if dim != dim_out: + self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) + else: + self.res_conv = torch.nn.Identity() + + def set_weights_grad_tts(self): + self.conv1.weight.data = self.block1.block[0].weight.data + self.conv1.bias.data = self.block1.block[0].bias.data + self.norm1.weight.data = self.block1.block[1].weight.data + self.norm1.bias.data = self.block1.block[1].bias.data + + self.conv2.weight.data = self.block2.block[0].weight.data + self.conv2.bias.data = self.block2.block[0].bias.data + self.norm2.weight.data = self.block2.block[1].weight.data + self.norm2.bias.data = self.block2.block[1].bias.data + + self.temb_proj.weight.data = self.mlp[1].weight.data + self.temb_proj.bias.data = self.mlp[1].bias.data + + if self.in_channels != self.out_channels: + self.nin_shortcut.weight.data = self.res_conv.weight.data + self.nin_shortcut.bias.data = self.res_conv.bias.data + + def forward(self, x, temb, mask=None): + if self.overwrite_for_grad_tts and not self.is_overwritten: + self.set_weights_grad_tts() + self.is_overwritten = True + h = x - h = self.norm1(h) - h = nonlinearity(h) + h = h * mask if mask is not None else h + if self.pre_norm: + h = self.norm1(h) + h = self.nonlinearity(h) + h = self.conv1(h) - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + if not self.pre_norm: + h = self.norm1(h) + h = self.nonlinearity(h) + h = h * mask if mask is not None else h + + h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None] + + h = h * mask if mask is not None else h + if self.pre_norm: + h = self.norm2(h) + h = self.nonlinearity(h) - h = self.norm2(h) - h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) + if not self.pre_norm: + h = self.norm2(h) + h = self.nonlinearity(h) + h = h * mask if mask is not None else h + + x = x * mask if mask is not None else x if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) @@ -383,58 +461,17 @@ class ResnetBlock(nn.Module): return x + h -# unet_grad_tts.py -class ResnetBlockGradTTS(torch.nn.Module): - def __init__(self, dim, dim_out, time_emb_dim, groups=8): - super(ResnetBlockGradTTS, self).__init__() - self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out)) - - self.block1 = Block(dim, dim_out, groups=groups) - self.block2 = Block(dim_out, dim_out, groups=groups) - if dim != dim_out: - self.res_conv = torch.nn.Conv2d(dim, dim_out, 1) - else: - self.res_conv = torch.nn.Identity() - - def forward(self, x, mask, time_emb): - h = self.block1(x, mask) - h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1) - h = self.block2(h, mask) - output = h + self.res_conv(x * mask) - return output - - -# unet_rl.py -class ResidualTemporalBlock(nn.Module): - def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): - super().__init__() - - self.blocks = nn.ModuleList( - [ - Conv1dBlock(inp_channels, out_channels, kernel_size), - Conv1dBlock(out_channels, out_channels, kernel_size), - ] +# TODO(Patrick) - just there to convert the weights; can delete afterward +class Block(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super(Block, self).__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish() ) - self.time_mlp = nn.Sequential( - nn.Mish(), - nn.Linear(embed_dim, out_channels), - RearrangeDim(), - # Rearrange("batch t -> batch t 1"), - ) - - self.residual_conv = ( - nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() - ) - - def forward(self, x, t): - """ - x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x - out_channels x horizon ] - """ - out = self.blocks[0](x) + self.time_mlp(t) - out = self.blocks[1](out) - return out + self.residual_conv(x) + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask # unet_score_estimation.py @@ -570,6 +607,39 @@ class ResnetBlockDDPMpp(nn.Module): return (x + h) / np.sqrt(2.0) +# unet_rl.py +class ResidualTemporalBlock(nn.Module): + def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5): + super().__init__() + + self.blocks = nn.ModuleList( + [ + Conv1dBlock(inp_channels, out_channels, kernel_size), + Conv1dBlock(out_channels, out_channels, kernel_size), + ] + ) + + self.time_mlp = nn.Sequential( + nn.Mish(), + nn.Linear(embed_dim, out_channels), + RearrangeDim(), + # Rearrange("batch t -> batch t 1"), + ) + + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() + ) + + def forward(self, x, t): + """ + x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x + out_channels x horizon ] + """ + out = self.blocks[0](x) + self.time_mlp(t) + out = self.blocks[1](out) + return out + self.residual_conv(x) + + # HELPER Modules @@ -617,18 +687,6 @@ class Mish(torch.nn.Module): return x * torch.tanh(torch.nn.functional.softplus(x)) -class Block(torch.nn.Module): - def __init__(self, dim, dim_out, groups=8): - super(Block, self).__init__() - self.block = torch.nn.Sequential( - torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish() - ) - - def forward(self, x, mask): - output = self.block(x * mask) - return output * mask - - class Conv1dBlock(nn.Module): """ Conv1d --> GroupNorm --> Mish diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index e6dabe6dd4..27c6264cfa 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -4,9 +4,7 @@ from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .attention import LinearAttention from .embeddings import get_timestep_embedding -from .resnet import Downsample -from .resnet import ResnetBlockGradTTS as ResnetBlock -from .resnet import Upsample +from .resnet import Downsample, ResnetBlock, Upsample class Mish(torch.nn.Module): @@ -86,8 +84,26 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): self.downs.append( torch.nn.ModuleList( [ - ResnetBlock(dim_in, dim_out, time_emb_dim=dim), - ResnetBlock(dim_out, dim_out, time_emb_dim=dim), + ResnetBlock( + in_channels=dim_in, + out_channels=dim_out, + temb_channels=dim, + groups=8, + pre_norm=False, + eps=1e-5, + non_linearity="mish", + overwrite_for_grad_tts=True, + ), + ResnetBlock( + in_channels=dim_out, + out_channels=dim_out, + temb_channels=dim, + groups=8, + pre_norm=False, + eps=1e-5, + non_linearity="mish", + overwrite_for_grad_tts=True, + ), Residual(Rezero(LinearAttention(dim_out))), Downsample(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(), ] @@ -95,16 +111,52 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): ) mid_dim = dims[-1] - self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) + self.mid_block1 = ResnetBlock( + in_channels=mid_dim, + out_channels=mid_dim, + temb_channels=dim, + groups=8, + pre_norm=False, + eps=1e-5, + non_linearity="mish", + overwrite_for_grad_tts=True, + ) self.mid_attn = Residual(Rezero(LinearAttention(mid_dim))) - self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim) + self.mid_block2 = ResnetBlock( + in_channels=mid_dim, + out_channels=mid_dim, + temb_channels=dim, + groups=8, + pre_norm=False, + eps=1e-5, + non_linearity="mish", + overwrite_for_grad_tts=True, + ) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): self.ups.append( torch.nn.ModuleList( [ - ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim), - ResnetBlock(dim_in, dim_in, time_emb_dim=dim), + ResnetBlock( + in_channels=dim_out * 2, + out_channels=dim_in, + temb_channels=dim, + groups=8, + pre_norm=False, + eps=1e-5, + non_linearity="mish", + overwrite_for_grad_tts=True, + ), + ResnetBlock( + in_channels=dim_in, + out_channels=dim_in, + temb_channels=dim, + groups=8, + pre_norm=False, + eps=1e-5, + non_linearity="mish", + overwrite_for_grad_tts=True, + ), Residual(Rezero(LinearAttention(dim_in))), Upsample(dim_in, use_conv_transpose=True), ] @@ -135,8 +187,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): masks = [mask] for resnet1, resnet2, attn, downsample in self.downs: mask_down = masks[-1] - x = resnet1(x, mask_down, t) - x = resnet2(x, mask_down, t) + x = resnet1(x, t, mask_down) + x = resnet2(x, t, mask_down) x = attn(x) hiddens.append(x) x = downsample(x * mask_down) @@ -144,15 +196,15 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): masks = masks[:-1] mask_mid = masks[-1] - x = self.mid_block1(x, mask_mid, t) + x = self.mid_block1(x, t, mask_mid) x = self.mid_attn(x) - x = self.mid_block2(x, mask_mid, t) + x = self.mid_block2(x, t, mask_mid) for resnet1, resnet2, attn, upsample in self.ups: mask_up = masks.pop() x = torch.cat((x, hiddens.pop()), dim=1) - x = resnet1(x, mask_up, t) - x = resnet2(x, mask_up, t) + x = resnet1(x, t, mask_up) + x = resnet2(x, t, mask_up) x = attn(x) x = upsample(x * mask_up)