From c174bcf4bf330e2512004f94b0e42406ad49442c Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 29 Jun 2022 14:35:18 +0000
Subject: [PATCH] finish

---
 src/diffusers/models/resnet.py        | 179 +-------------------------
 src/diffusers/models/unet_grad_tts.py |  12 +-
 2 files changed, 8 insertions(+), 183 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index cd0925cb67..1a45a0e3dc 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -340,49 +340,7 @@ class ResBlock(TimestepBlock):
         return self.skip_connection(x) + h
 
 
-# unet.py
-class OLD_ResnetBlock(nn.Module):
-    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
-        super().__init__()
-        self.in_channels = in_channels
-        out_channels = in_channels if out_channels is None else out_channels
-        self.out_channels = out_channels
-        self.use_conv_shortcut = conv_shortcut
-
-        self.norm1 = Normalize(in_channels)
-        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-        self.norm2 = Normalize(out_channels)
-        self.dropout = torch.nn.Dropout(dropout)
-        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-        if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-            else:
-                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
-    def forward(self, x, temb):
-        h = x
-        h = self.norm1(h)
-        h = nonlinearity(h)
-        h = self.conv1(h)
-
-        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
-
-        h = self.norm2(h)
-        h = nonlinearity(h)
-        h = self.dropout(h)
-        h = self.conv2(h)
-
-        if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
-            else:
-                x = self.nin_shortcut(x)
-
-        return x + h
-
-
+# unet.py and unet_grad_tts.py
 class ResnetBlock(nn.Module):
     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=512, groups=32, pre_norm=True, eps=1e-6, non_linearity="swish", overwrite_for_grad_tts=False):
         super().__init__()
@@ -429,11 +387,6 @@ class ResnetBlock(nn.Module):
             else:
                 self.res_conv = torch.nn.Identity()
 
-#             num_groups = 8
-#             self.pre_norm = False
-#             eps = 1e-5
-#             non_linearity = "mish"
-
     def set_weights_grad_tts(self):
         self.conv1.weight.data = self.block1.block[0].weight.data
         self.conv1.bias.data = self.block1.block[0].bias.data
@@ -453,11 +406,6 @@ class ResnetBlock(nn.Module):
             self.nin_shortcut.bias.data = self.res_conv.bias.data
 
     def forward(self, x, temb, mask=None):
-        if not self.pre_norm:
-            temp = mask
-            mask = temb
-            temb = temp
-
         if self.overwrite_for_grad_tts and not self.is_overwritten:
             self.set_weights_grad_tts()
             self.is_overwritten = True
@@ -500,130 +448,7 @@ class ResnetBlock(nn.Module):
         return x + h
 
 
-# unet_grad_tts.py
-class ResnetBlockGradTTS(torch.nn.Module):
-    def __init__(self, dim, dim_out, time_emb_dim, groups=8, eps=1e-6, overwrite=True, conv_shortcut=False, pre_norm=True):
-        super(ResnetBlockGradTTS, self).__init__()
-        self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim, dim_out))
-        self.pre_norm = pre_norm
-
-        self.block1 = Block(dim, dim_out, groups=groups)
-        self.block2 = Block(dim_out, dim_out, groups=groups)
-        if dim != dim_out:
-            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
-        else:
-            self.res_conv = torch.nn.Identity()
-
-        self.overwrite = overwrite
-        if self.overwrite:
-            in_channels = dim
-            out_channels = dim_out
-            temb_channels = time_emb_dim
-
-            # To set via init
-            self.pre_norm = False
-            eps = 1e-5
-
-            self.in_channels = in_channels
-            out_channels = in_channels if out_channels is None else out_channels
-            self.out_channels = out_channels
-            self.use_conv_shortcut = conv_shortcut
-
-            if self.pre_norm:
-                self.norm1 = Normalize(in_channels, num_groups=groups, eps=eps)
-            else:
-                self.norm1 = Normalize(out_channels, num_groups=groups, eps=eps)
-
-            self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
-            self.norm2 = Normalize(out_channels, num_groups=groups, eps=eps)
-            dropout = 0.0
-            self.dropout = torch.nn.Dropout(dropout)
-            self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
-            if self.in_channels != self.out_channels:
-                if self.use_conv_shortcut:
-                    self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
-                else:
-                    self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-
-            self.nonlinearity = Mish()
-
-        self.is_overwritten = False
-
-    def set_weights(self):
-        self.conv1.weight.data = self.block1.block[0].weight.data
-        self.conv1.bias.data = self.block1.block[0].bias.data
-        self.norm1.weight.data = self.block1.block[1].weight.data
-        self.norm1.bias.data = self.block1.block[1].bias.data
-
-        self.conv2.weight.data = self.block2.block[0].weight.data
-        self.conv2.bias.data = self.block2.block[0].bias.data
-        self.norm2.weight.data = self.block2.block[1].weight.data
-        self.norm2.bias.data = self.block2.block[1].bias.data
-
-        self.temb_proj.weight.data = self.mlp[1].weight.data
-        self.temb_proj.bias.data = self.mlp[1].bias.data
-
-        if self.in_channels != self.out_channels:
-            self.nin_shortcut.weight.data = self.res_conv.weight.data
-            self.nin_shortcut.bias.data = self.res_conv.bias.data
-
-    def forward(self, x, mask, time_emb):
-        h = self.block1(x, mask)
-        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
-        h = self.block2(h, mask)
-        output = h + self.res_conv(x * mask)
-
-        output = self.forward_2(x, time_emb, mask=mask)
-        return output
-
-    def forward_2(self, x, temb, mask=None):
-        if not self.is_overwritten:
-            self.set_weights()
-            self.is_overwritten = True
-
-        if mask is None:
-            mask = torch.ones_like(x)
-
-        h = x
-
-        h = h * mask
-        if self.pre_norm:
-            h = self.norm1(h)
-            h = self.nonlinearity(h)
-
-        h = self.conv1(h)
-
-        if not self.pre_norm:
-            h = self.norm1(h)
-            h = self.nonlinearity(h)
-        h = h * mask
-
-        h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
-
-        h = h * mask
-        if self.pre_norm:
-            h = self.norm2(h)
-            h = self.nonlinearity(h)
-
-        h = self.dropout(h)
-        h = self.conv2(h)
-
-        if not self.pre_norm:
-            h = self.norm2(h)
-            h = self.nonlinearity(h)
-        h = h * mask
-
-        x = x * mask
-        if self.in_channels != self.out_channels:
-            if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
-            else:
-                x = self.nin_shortcut(x)
-
-        return x + h
-
-
+# TODO(Patrick) - just there to convert the weights; can delete afterward
 class Block(torch.nn.Module):
     def __init__(self, dim, dim_out, groups=8):
         super(Block, self).__init__()
diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py
index 56614f8d9d..8316fd3ce3 100644
--- a/src/diffusers/models/unet_grad_tts.py
+++ b/src/diffusers/models/unet_grad_tts.py
@@ -135,8 +135,8 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         masks = [mask]
         for resnet1, resnet2, attn, downsample in self.downs:
             mask_down = masks[-1]
-            x = resnet1(x, mask_down, t)
-            x = resnet2(x, mask_down, t)
+            x = resnet1(x, t, mask_down)
+            x = resnet2(x, t, mask_down)
             x = attn(x)
             hiddens.append(x)
             x = downsample(x * mask_down)
@@ -144,15 +144,15 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
 
         masks = masks[:-1]
         mask_mid = masks[-1]
-        x = self.mid_block1(x, mask_mid, t)
+        x = self.mid_block1(x, t, mask_mid)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, mask_mid, t)
+        x = self.mid_block2(x, t, mask_mid)
 
         for resnet1, resnet2, attn, upsample in self.ups:
             mask_up = masks.pop()
             x = torch.cat((x, hiddens.pop()), dim=1)
-            x = resnet1(x, mask_up, t)
-            x = resnet2(x, mask_up, t)
+            x = resnet1(x, t, mask_up)
+            x = resnet2(x, t, mask_up)
             x = attn(x)
             x = upsample(x * mask_up)
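
Note (illustrative, not part of the patch): the callers in unet_grad_tts.py now pass the timestep embedding before the mask, matching the consolidated ResnetBlock.forward(x, temb, mask=None) signature in resnet.py. A minimal runnable sketch of that calling convention, using a hypothetical stub module in place of the real block:

    import torch

    class StubResnetBlock(torch.nn.Module):
        # Hypothetical stand-in that only mirrors the unified signature.
        def forward(self, x, temb, mask=None):
            if mask is None:
                mask = torch.ones_like(x)
            # The real block applies norm/conv layers and adds the projected
            # temb; here we only gate the activations with the mask.
            return x * mask

    resnet = StubResnetBlock()
    x = torch.randn(1, 4, 8, 8)    # (batch, channels, height, width)
    t = torch.randn(1, 32)         # timestep embedding
    mask = torch.ones(1, 1, 8, 8)  # broadcastable mask
    out = resnet(x, t, mask)       # new argument order: (x, t, mask), not (x, mask, t)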