Mirror of https://github.com/huggingface/diffusers.git

[Clean up] Clean unused code (#245)

* CleanResNet

* refactor more

* correct
Patrick von Platen authored on 2022-08-25 11:55:57 +02:00, committed by GitHub
parent 47893164ab, commit c1efda70b5
4 changed files with 39 additions and 247 deletions
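In short, this commit renames AttentionBlockNew to AttentionBlock and ResnetBlock to ResnetBlock2D, and deletes the weight-porting helpers (the set_weight methods, the NIN shim, and the exists/default utilities) that were only needed while checkpoints were being converted. As a rough sketch of how the renamed blocks are constructed after this commit, using only the keyword arguments visible in the hunks below; the channel sizes, shapes, and forward signatures are assumptions for illustration, not taken from the repository's tests:

import torch
from diffusers.models.attention import AttentionBlock
from diffusers.models.resnet import ResnetBlock2D

# Hypothetical sizes, chosen only so the modules can be instantiated.
resnet = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=128)
attn = AttentionBlock(64, num_head_channels=8, rescale_output_factor=1.0)

sample = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
temb = torch.randn(1, 128)           # time embedding, as the UNet blocks pass it
hidden = resnet(sample, temb)        # assumed forward signature
hidden = attn(hidden)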

View File

@@ -390,7 +390,7 @@ class ModelMixin(torch.nn.Module):
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
)
except HTTPError as err:
raise EnvironmentError(

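The only change in this file is the error message raised when the weights file is missing from the Hub repository: it now interpolates the WEIGHTS_NAME constant instead of the local model_file variable, so the user is told which filename was actually expected. A minimal sketch of the surrounding pattern, with fetch_from_hub and the EntryNotFoundError class written here as stand-ins for the hub download call used by ModelMixin.from_pretrained:

WEIGHTS_NAME = "diffusion_pytorch_model.bin"  # value of the constant in diffusers

class EntryNotFoundError(Exception):
    """Stand-in for the hub error raised when a file is absent from the repo."""

def load_weights(pretrained_model_name_or_path, fetch_from_hub):
    try:
        model_file = fetch_from_hub(pretrained_model_name_or_path, filename=WEIGHTS_NAME)
    except EntryNotFoundError:
        # Name the expected weights file, not the local variable.
        raise EnvironmentError(
            f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
        )
    return model_file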
View File

@@ -1,12 +1,11 @@
import math
-from inspect import isfunction
import torch
import torch.nn.functional as F
from torch import nn
-class AttentionBlockNew(nn.Module):
+class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
@@ -82,55 +81,6 @@ class AttentionBlockNew(nn.Module):
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
-def set_weight(self, attn_layer):
-self.group_norm.weight.data = attn_layer.norm.weight.data
-self.group_norm.bias.data = attn_layer.norm.bias.data
-if hasattr(attn_layer, "q"):
-self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
-self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
-self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]
-self.query.bias.data = attn_layer.q.bias.data
-self.key.bias.data = attn_layer.k.bias.data
-self.value.bias.data = attn_layer.v.bias.data
-self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
-self.proj_attn.bias.data = attn_layer.proj_out.bias.data
-elif hasattr(attn_layer, "NIN_0"):
-self.query.weight.data = attn_layer.NIN_0.W.data.T
-self.key.weight.data = attn_layer.NIN_1.W.data.T
-self.value.weight.data = attn_layer.NIN_2.W.data.T
-self.query.bias.data = attn_layer.NIN_0.b.data
-self.key.bias.data = attn_layer.NIN_1.b.data
-self.value.bias.data = attn_layer.NIN_2.b.data
-self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
-self.proj_attn.bias.data = attn_layer.NIN_3.b.data
-self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
-self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
-else:
-qkv_weight = attn_layer.qkv.weight.data.reshape(
-self.num_heads, 3 * self.channels // self.num_heads, self.channels
-)
-qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
-q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
-q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
-self.query.weight.data = q_w.reshape(-1, self.channels)
-self.key.weight.data = k_w.reshape(-1, self.channels)
-self.value.weight.data = v_w.reshape(-1, self.channels)
-self.query.bias.data = q_b.reshape(-1)
-self.key.bias.data = k_b.reshape(-1)
-self.value.bias.data = v_b.reshape(-1)
-self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
-self.proj_attn.bias.data = attn_layer.proj.bias.data
class SpatialTransformer(nn.Module):
"""
@@ -170,12 +120,6 @@ class SpatialTransformer(nn.Module):
x = self.proj_out(x)
return x + x_in
-def set_weight(self, layer):
-self.norm = layer.norm
-self.proj_in = layer.proj_in
-self.transformer_blocks = layer.transformer_blocks
-self.proj_out = layer.proj_out
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
@@ -203,7 +147,7 @@ class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
-context_dim = default(context_dim, query_dim)
+context_dim = context_dim if context_dim is not None else query_dim
self.scale = dim_head**-0.5
self.heads = heads
@@ -234,7 +178,7 @@ class CrossAttention(nn.Module):
h = self.heads
q = self.to_q(x)
-context = default(context, x)
+context = context if context is not None else x
k = self.to_k(context)
v = self.to_v(context)
@@ -244,7 +188,7 @@ class CrossAttention(nn.Module):
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
-if exists(mask):
+if mask is not None:
mask = mask.reshape(batch_size, -1)
max_neg_value = -torch.finfo(sim.dtype).max
mask = mask[:, None, :].repeat(h, 1, 1)
@@ -262,8 +206,8 @@ class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
-dim_out = default(dim_out, dim)
-project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+dim_out = dim_out if dim_out is not None else dim
+project_in = GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
@@ -280,155 +224,3 @@ class GEGLU(nn.Module):
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
-# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
-class NIN(nn.Module):
-def __init__(self, in_dim, num_units, init_scale=0.1):
-super().__init__()
-self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
-self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
-def exists(val):
-return val is not None
-def default(val, d):
-if exists(val):
-return val
-return d() if isfunction(d) else d
-# the main attention block that is used for all models
-class AttentionBlock(nn.Module):
-"""
-An attention block that allows spatial positions to attend to each other.
-Originally ported from here, but adapted to the N-d case.
-https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-"""
-def __init__(
-self,
-channels,
-num_heads=1,
-num_head_channels=None,
-num_groups=32,
-encoder_channels=None,
-overwrite_qkv=False,
-overwrite_linear=False,
-rescale_output_factor=1.0,
-eps=1e-5,
-):
-super().__init__()
-self.channels = channels
-if num_head_channels is None:
-self.num_heads = num_heads
-else:
-assert (
-channels % num_head_channels == 0
-), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-self.num_heads = channels // num_head_channels
-self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
-self.qkv = nn.Conv1d(channels, channels * 3, 1)
-self.n_heads = self.num_heads
-self.rescale_output_factor = rescale_output_factor
-if encoder_channels is not None:
-self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
-self.proj = nn.Conv1d(channels, channels, 1)
-self.overwrite_qkv = overwrite_qkv
-self.overwrite_linear = overwrite_linear
-if overwrite_qkv:
-in_channels = channels
-self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
-elif self.overwrite_linear:
-num_groups = min(channels // 4, 32)
-self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
-self.NIN_0 = NIN(channels, channels)
-self.NIN_1 = NIN(channels, channels)
-self.NIN_2 = NIN(channels, channels)
-self.NIN_3 = NIN(channels, channels)
-self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
-else:
-self.proj_out = nn.Conv1d(channels, channels, 1)
-self.set_weights(self)
-self.is_overwritten = False
-def set_weights(self, module):
-if self.overwrite_qkv:
-qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
-:, :, :, 0
-]
-qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
-self.qkv.weight.data = qkv_weight
-self.qkv.bias.data = qkv_bias
-proj_out = nn.Conv1d(self.channels, self.channels, 1)
-proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
-proj_out.bias.data = module.proj_out.bias.data
-self.proj = proj_out
-elif self.overwrite_linear:
-self.qkv.weight.data = torch.concat(
-[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
-)[:, :, None]
-self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
-self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
-self.proj.bias.data = self.NIN_3.b.data
-self.norm.weight.data = self.GroupNorm_0.weight.data
-self.norm.bias.data = self.GroupNorm_0.bias.data
-else:
-self.proj.weight.data = self.proj_out.weight.data
-self.proj.bias.data = self.proj_out.bias.data
-def forward(self, x, encoder_out=None):
-if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
-self.set_weights(self)
-self.is_overwritten = True
-b, c, *spatial = x.shape
-hid_states = self.norm(x).view(b, c, -1)
-qkv = self.qkv(hid_states)
-bs, width, length = qkv.shape
-assert width % (3 * self.n_heads) == 0
-ch = width // (3 * self.n_heads)
-q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-if encoder_out is not None:
-encoder_kv = self.encoder_kv(encoder_out)
-assert encoder_kv.shape[1] == self.n_heads * ch * 2
-ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-k = torch.cat([ek, k], dim=-1)
-v = torch.cat([ev, v], dim=-1)
-scale = 1 / math.sqrt(math.sqrt(ch))
-weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
-weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
-a = torch.einsum("bts,bcs->bct", weight, v)
-h = a.reshape(bs, -1, length)
-h = self.proj(h)
-h = h.reshape(b, c, *spatial)
-result = x + h
-result = result / self.rescale_output_factor
-return result
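The exists and default helpers deleted at the end of this file were thin wrappers around None checks, which is why the call sites above could be rewritten inline. A small self-contained sketch of the equivalence (pick_context is a hypothetical name, not part of the library):

from inspect import isfunction

# Helpers removed in this commit, copied from the hunk above.
def exists(val):
    return val is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

# The inlined style now used in CrossAttention and FeedForward.
def pick_context(context, x):
    return context if context is not None else x

assert default(None, "x") == pick_context(None, "x") == "x"
assert default("ctx", "x") == pick_context("ctx", "x") == "ctx"

One behavioural difference worth noting: default also accepted a callable fallback and evaluated it lazily, while the inlined None checks do not; none of the call sites touched here relied on that.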

View File

@@ -248,7 +248,7 @@ class FirDownsample2D(nn.Module):
return x
-class ResnetBlock(nn.Module):
+class ResnetBlock2D(nn.Module):
def __init__(
self,
*,

View File

@@ -17,8 +17,8 @@ import numpy as np
import torch
from torch import nn
-from .attention import AttentionBlockNew, SpatialTransformer
-from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
+from .attention import AttentionBlock, SpatialTransformer
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
def get_down_block(
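The hunks that follow are mechanical renames inside the UNet building blocks, which interleave lists of ResnetBlock2D and AttentionBlock / SpatialTransformer modules. A rough sketch of the construction pattern used in UNetMidBlock2D, trimmed to the keyword arguments that appear in those hunks; build_mid_block is an illustrative helper, not the library's API:

import torch.nn as nn

from diffusers.models.attention import AttentionBlock
from diffusers.models.resnet import ResnetBlock2D

def build_mid_block(in_channels, temb_channels, attn_num_head_channels=None, num_layers=1):
    # There is always at least one resnet, followed by (attention, resnet) pairs.
    resnets = [
        ResnetBlock2D(in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels)
    ]
    attentions = []
    for _ in range(num_layers):
        attentions.append(AttentionBlock(in_channels, num_head_channels=attn_num_head_channels))
        resnets.append(
            ResnetBlock2D(in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels)
        )
    return nn.ModuleList(attentions), nn.ModuleList(resnets)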
@@ -219,7 +219,7 @@ class UNetMidBlock2D(nn.Module):
# there is always at least one resnet
resnets = [
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
@@ -236,7 +236,7 @@ class UNetMidBlock2D(nn.Module):
for _ in range(num_layers):
attentions.append(
-AttentionBlockNew(
+AttentionBlock(
in_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
@@ -245,7 +245,7 @@ class UNetMidBlock2D(nn.Module):
)
)
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
@@ -299,7 +299,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
# there is always at least one resnet
resnets = [
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
@@ -325,7 +325,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
)
)
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
@@ -379,7 +379,7 @@ class AttnDownBlock2D(nn.Module):
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -393,7 +393,7 @@ class AttnDownBlock2D(nn.Module):
)
)
attentions.append(
-AttentionBlockNew(
+AttentionBlock(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
@@ -461,7 +461,7 @@ class CrossAttnDownBlock2D(nn.Module):
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -537,7 +537,7 @@ class DownBlock2D(nn.Module):
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -602,7 +602,7 @@ class DownEncoderBlock2D(nn.Module):
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
@@ -664,7 +664,7 @@ class AttnDownEncoderBlock2D(nn.Module):
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
@@ -678,7 +678,7 @@ class AttnDownEncoderBlock2D(nn.Module):
)
)
attentions.append(
-AttentionBlockNew(
+AttentionBlock(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
@@ -740,7 +740,7 @@ class AttnSkipDownBlock2D(nn.Module):
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -755,7 +755,7 @@ class AttnSkipDownBlock2D(nn.Module):
)
)
self.attentions.append(
-AttentionBlockNew(
+AttentionBlock(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
@@ -764,7 +764,7 @@ class AttnSkipDownBlock2D(nn.Module):
)
if add_downsample:
-self.resnet_down = ResnetBlock(
+self.resnet_down = ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -828,7 +828,7 @@ class SkipDownBlock2D(nn.Module):
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -844,7 +844,7 @@ class SkipDownBlock2D(nn.Module):
)
if add_downsample:
-self.resnet_down = ResnetBlock(
+self.resnet_down = ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -915,7 +915,7 @@ class AttnUpBlock2D(nn.Module):
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -929,7 +929,7 @@ class AttnUpBlock2D(nn.Module):
)
)
attentions.append(
-AttentionBlockNew(
+AttentionBlock(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
@@ -995,7 +995,7 @@ class CrossAttnUpBlock2D(nn.Module):
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -1068,7 +1068,7 @@ class UpBlock2D(nn.Module):
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -1128,7 +1128,7 @@ class UpDecoderBlock2D(nn.Module):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=None,
@@ -1184,7 +1184,7 @@ class AttnUpDecoderBlock2D(nn.Module):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=None,
@@ -1198,7 +1198,7 @@ class AttnUpDecoderBlock2D(nn.Module):
)
)
attentions.append(
-AttentionBlockNew(
+AttentionBlock(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
@@ -1257,7 +1257,7 @@ class AttnSkipUpBlock2D(nn.Module):
resnet_in_channels = prev_output_channel if i == 0 else out_channels
self.resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -1273,7 +1273,7 @@ class AttnSkipUpBlock2D(nn.Module):
)
self.attentions.append(
-AttentionBlockNew(
+AttentionBlock(
out_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
@@ -1283,7 +1283,7 @@ class AttnSkipUpBlock2D(nn.Module):
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
if add_upsample:
-self.resnet_up = ResnetBlock(
+self.resnet_up = ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -1363,7 +1363,7 @@ class SkipUpBlock2D(nn.Module):
resnet_in_channels = prev_output_channel if i == 0 else out_channels
self.resnets.append(
-ResnetBlock(
+ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
@@ -1380,7 +1380,7 @@ class SkipUpBlock2D(nn.Module):
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
if add_upsample:
-self.resnet_up = ResnetBlock(
+self.resnet_up = ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,