From c1efda70b52dc05857ad214106754d5e2028fc26 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 25 Aug 2022 11:55:57 +0200 Subject: [PATCH] [Clean up] Clean unused code (#245) * CleanResNet * refactor more * correct --- src/diffusers/modeling_utils.py | 2 +- src/diffusers/models/attention.py | 220 +--------------------------- src/diffusers/models/resnet.py | 2 +- src/diffusers/models/unet_blocks.py | 62 ++++---- 4 files changed, 39 insertions(+), 247 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 1ecf27a2a9..a8489a7b95 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -390,7 +390,7 @@ class ModelMixin(torch.nn.Module): ) except EntryNotFoundError: raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}." + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}." ) except HTTPError as err: raise EnvironmentError( diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index dd22cdbb95..8d52ee9bde 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1,12 +1,11 @@ import math -from inspect import isfunction import torch import torch.nn.functional as F from torch import nn -class AttentionBlockNew(nn.Module): +class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. @@ -82,55 +81,6 @@ class AttentionBlockNew(nn.Module): hidden_states = (hidden_states + residual) / self.rescale_output_factor return hidden_states - def set_weight(self, attn_layer): - self.group_norm.weight.data = attn_layer.norm.weight.data - self.group_norm.bias.data = attn_layer.norm.bias.data - - if hasattr(attn_layer, "q"): - self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0] - self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0] - self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0] - - self.query.bias.data = attn_layer.q.bias.data - self.key.bias.data = attn_layer.k.bias.data - self.value.bias.data = attn_layer.v.bias.data - - self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0] - self.proj_attn.bias.data = attn_layer.proj_out.bias.data - elif hasattr(attn_layer, "NIN_0"): - self.query.weight.data = attn_layer.NIN_0.W.data.T - self.key.weight.data = attn_layer.NIN_1.W.data.T - self.value.weight.data = attn_layer.NIN_2.W.data.T - - self.query.bias.data = attn_layer.NIN_0.b.data - self.key.bias.data = attn_layer.NIN_1.b.data - self.value.bias.data = attn_layer.NIN_2.b.data - - self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T - self.proj_attn.bias.data = attn_layer.NIN_3.b.data - - self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data - self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data - else: - qkv_weight = attn_layer.qkv.weight.data.reshape( - self.num_heads, 3 * self.channels // self.num_heads, self.channels - ) - qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads) - - q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1) - q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1) - - self.query.weight.data = q_w.reshape(-1, self.channels) - self.key.weight.data = k_w.reshape(-1, self.channels) - self.value.weight.data = v_w.reshape(-1, self.channels) - - self.query.bias.data = q_b.reshape(-1) - self.key.bias.data = 
k_b.reshape(-1) - self.value.bias.data = v_b.reshape(-1) - - self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0] - self.proj_attn.bias.data = attn_layer.proj.bias.data - class SpatialTransformer(nn.Module): """ @@ -170,12 +120,6 @@ class SpatialTransformer(nn.Module): x = self.proj_out(x) return x + x_in - def set_weight(self, layer): - self.norm = layer.norm - self.proj_in = layer.proj_in - self.transformer_blocks = layer.transformer_blocks - self.proj_out = layer.proj_out - class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True): @@ -203,7 +147,7 @@ class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) + context_dim = context_dim if context_dim is not None else query_dim self.scale = dim_head**-0.5 self.heads = heads @@ -234,7 +178,7 @@ class CrossAttention(nn.Module): h = self.heads q = self.to_q(x) - context = default(context, x) + context = context if context is not None else x k = self.to_k(context) v = self.to_v(context) @@ -244,7 +188,7 @@ class CrossAttention(nn.Module): sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale - if exists(mask): + if mask is not None: mask = mask.reshape(batch_size, -1) max_neg_value = -torch.finfo(sim.dtype).max mask = mask[:, None, :].repeat(h, 1, 1) @@ -262,8 +206,8 @@ class FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + dim_out = dim_out if dim_out is not None else dim + project_in = GEGLU(dim, inner_dim) self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) @@ -280,155 +224,3 @@ class GEGLU(nn.Module): def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) return x * F.gelu(gate) - - -# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then -class NIN(nn.Module): - def __init__(self, in_dim, num_units, init_scale=0.1): - super().__init__() - self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True) - self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True) - - -def exists(val): - return val is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -# the main attention block that is used for all models -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
- """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=None, - num_groups=32, - encoder_channels=None, - overwrite_qkv=False, - overwrite_linear=False, - rescale_output_factor=1.0, - eps=1e-5, - ): - super().__init__() - self.channels = channels - if num_head_channels is None: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - - self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) - self.qkv = nn.Conv1d(channels, channels * 3, 1) - self.n_heads = self.num_heads - self.rescale_output_factor = rescale_output_factor - - if encoder_channels is not None: - self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1) - - self.proj = nn.Conv1d(channels, channels, 1) - - self.overwrite_qkv = overwrite_qkv - self.overwrite_linear = overwrite_linear - - if overwrite_qkv: - in_channels = channels - self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6) - self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - elif self.overwrite_linear: - num_groups = min(channels // 4, 32) - self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6) - self.NIN_0 = NIN(channels, channels) - self.NIN_1 = NIN(channels, channels) - self.NIN_2 = NIN(channels, channels) - self.NIN_3 = NIN(channels, channels) - - self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6) - else: - self.proj_out = nn.Conv1d(channels, channels, 1) - self.set_weights(self) - - self.is_overwritten = False - - def set_weights(self, module): - if self.overwrite_qkv: - qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[ - :, :, :, 0 - ] - qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0) - - self.qkv.weight.data = qkv_weight - self.qkv.bias.data = qkv_bias - - proj_out = nn.Conv1d(self.channels, self.channels, 1) - proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0] - proj_out.bias.data = module.proj_out.bias.data - - self.proj = proj_out - elif self.overwrite_linear: - self.qkv.weight.data = torch.concat( - [self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0 - )[:, :, None] - self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0) - - self.proj.weight.data = self.NIN_3.W.data.T[:, :, None] - self.proj.bias.data = self.NIN_3.b.data - - self.norm.weight.data = self.GroupNorm_0.weight.data - self.norm.bias.data = self.GroupNorm_0.bias.data - else: - self.proj.weight.data = self.proj_out.weight.data - self.proj.bias.data = self.proj_out.bias.data - - def forward(self, x, encoder_out=None): - if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear): - self.set_weights(self) - self.is_overwritten = True - - b, c, *spatial = x.shape - hid_states = self.norm(x).view(b, c, -1) - - qkv = self.qkv(hid_states) - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, 
length).split(ch, dim=1) - - if encoder_out is not None: - encoder_kv = self.encoder_kv(encoder_out) - assert encoder_kv.shape[1] == self.n_heads * ch * 2 - ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) - k = torch.cat([ek, k], dim=-1) - v = torch.cat([ev, v], dim=-1) - - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - - a = torch.einsum("bts,bcs->bct", weight, v) - h = a.reshape(bs, -1, length) - - h = self.proj(h) - h = h.reshape(b, c, *spatial) - - result = x + h - - result = result / self.rescale_output_factor - - return result diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index c61aa27095..acce7b574e 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -248,7 +248,7 @@ class FirDownsample2D(nn.Module): return x -class ResnetBlock(nn.Module): +class ResnetBlock2D(nn.Module): def __init__( self, *, diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index 9d01949554..bf9e0198d7 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -17,8 +17,8 @@ import numpy as np import torch from torch import nn -from .attention import AttentionBlockNew, SpatialTransformer -from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D +from .attention import AttentionBlock, SpatialTransformer +from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D def get_down_block( @@ -219,7 +219,7 @@ class UNetMidBlock2D(nn.Module): # there is always at least one resnet resnets = [ - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -236,7 +236,7 @@ class UNetMidBlock2D(nn.Module): for _ in range(num_layers): attentions.append( - AttentionBlockNew( + AttentionBlock( in_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, @@ -245,7 +245,7 @@ class UNetMidBlock2D(nn.Module): ) ) resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -299,7 +299,7 @@ class UNetMidBlock2DCrossAttn(nn.Module): # there is always at least one resnet resnets = [ - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -325,7 +325,7 @@ class UNetMidBlock2DCrossAttn(nn.Module): ) ) resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -379,7 +379,7 @@ class AttnDownBlock2D(nn.Module): for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -393,7 +393,7 @@ class AttnDownBlock2D(nn.Module): ) ) attentions.append( - AttentionBlockNew( + AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, @@ -461,7 +461,7 @@ class CrossAttnDownBlock2D(nn.Module): for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -537,7 +537,7 @@ class DownBlock2D(nn.Module): for i in 
range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -602,7 +602,7 @@ class DownEncoderBlock2D(nn.Module): for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, @@ -664,7 +664,7 @@ class AttnDownEncoderBlock2D(nn.Module): for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, @@ -678,7 +678,7 @@ class AttnDownEncoderBlock2D(nn.Module): ) ) attentions.append( - AttentionBlockNew( + AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, @@ -740,7 +740,7 @@ class AttnSkipDownBlock2D(nn.Module): for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels self.resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -755,7 +755,7 @@ class AttnSkipDownBlock2D(nn.Module): ) ) self.attentions.append( - AttentionBlockNew( + AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, @@ -764,7 +764,7 @@ class AttnSkipDownBlock2D(nn.Module): ) if add_downsample: - self.resnet_down = ResnetBlock( + self.resnet_down = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -828,7 +828,7 @@ class SkipDownBlock2D(nn.Module): for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels self.resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -844,7 +844,7 @@ class SkipDownBlock2D(nn.Module): ) if add_downsample: - self.resnet_down = ResnetBlock( + self.resnet_down = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -915,7 +915,7 @@ class AttnUpBlock2D(nn.Module): resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -929,7 +929,7 @@ class AttnUpBlock2D(nn.Module): ) ) attentions.append( - AttentionBlockNew( + AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, @@ -995,7 +995,7 @@ class CrossAttnUpBlock2D(nn.Module): resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1068,7 +1068,7 @@ class UpBlock2D(nn.Module): resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1128,7 +1128,7 @@ class UpDecoderBlock2D(nn.Module): input_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, temb_channels=None, @@ -1184,7 +1184,7 @@ class AttnUpDecoderBlock2D(nn.Module): input_channels = in_channels if i == 
0 else out_channels resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, temb_channels=None, @@ -1198,7 +1198,7 @@ class AttnUpDecoderBlock2D(nn.Module): ) ) attentions.append( - AttentionBlockNew( + AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, @@ -1257,7 +1257,7 @@ class AttnSkipUpBlock2D(nn.Module): resnet_in_channels = prev_output_channel if i == 0 else out_channels self.resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1273,7 +1273,7 @@ class AttnSkipUpBlock2D(nn.Module): ) self.attentions.append( - AttentionBlockNew( + AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, @@ -1283,7 +1283,7 @@ class AttnSkipUpBlock2D(nn.Module): self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: - self.resnet_up = ResnetBlock( + self.resnet_up = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1363,7 +1363,7 @@ class SkipUpBlock2D(nn.Module): resnet_in_channels = prev_output_channel if i == 0 else out_channels self.resnets.append( - ResnetBlock( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -1380,7 +1380,7 @@ class SkipUpBlock2D(nn.Module): self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: - self.resnet_up = ResnetBlock( + self.resnet_up = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels,
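
The attention.py hunks above also drop the module-level `exists` and `default` helpers (together with the now-unused `isfunction` import); every call site spells the fallback out as an inline `None` check instead. A minimal sketch of that pattern in isolation, using a made-up TinySelfOrCrossAttention module rather than the real CrossAttention class:

import torch
from torch import nn


class TinySelfOrCrossAttention(nn.Module):
    """Sketch of the inline-None-check style adopted by the patch; not the real CrossAttention."""

    def __init__(self, query_dim, context_dim=None, dim_head=64):
        super().__init__()
        # Before the patch: context_dim = default(context_dim, query_dim)
        context_dim = context_dim if context_dim is not None else query_dim
        self.scale = dim_head**-0.5
        self.to_q = nn.Linear(query_dim, dim_head, bias=False)
        self.to_k = nn.Linear(context_dim, dim_head, bias=False)
        self.to_v = nn.Linear(context_dim, dim_head, bias=False)

    def forward(self, x, context=None, mask=None):
        # Before the patch: context = default(context, x)
        context = context if context is not None else x
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
        # Before the patch: if exists(mask): ...
        if mask is not None:
            sim = sim.masked_fill(~mask[:, None, :], -torch.finfo(sim.dtype).max)
        attn = sim.softmax(dim=-1)
        return torch.einsum("b i j, b j d -> b i d", attn, v)


# Self-attention when no context is passed, cross-attention otherwise.
block = TinySelfOrCrossAttention(query_dim=32)
self_out = block(torch.randn(2, 10, 32))                          # (2, 10, 64)
cross_out = block(torch.randn(2, 10, 32), torch.randn(2, 7, 32))  # (2, 10, 64)

For the call sites touched in this patch the behaviour is unchanged; only the helper indirection goes away.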
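
The renames are mechanical but they change the public module contents: AttentionBlockNew becomes AttentionBlock (the previous AttentionBlock, its NIN helper, and the set_weight/set_weights conversion code are deleted outright), ResnetBlock becomes ResnetBlock2D, and unet_blocks.py imports the new names. A hedged migration sketch for downstream code that used the old names; the keyword arguments are the ones visible in the UNetMidBlock2D hunks, while the concrete values (64 channels, 8 head channels, 512 temb channels) are illustrative assumptions:

# Pre-patch imports that no longer resolve after this change:
#   from diffusers.models.attention import AttentionBlockNew
#   from diffusers.models.resnet import ResnetBlock
from diffusers.models.attention import AttentionBlock
from diffusers.models.resnet import ResnetBlock2D

# Same modules as before, only the class names differ.
attn = AttentionBlock(64, num_head_channels=8, rescale_output_factor=1.0)
resnet = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=512)

print(sum(p.numel() for p in attn.parameters()))
print(sum(p.numel() for p in resnet.parameters()))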