mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[Clean up] Clean unused code (#245)
* CleanResNet * refactor more * correct
This commit is contained in:
committed by
GitHub
parent
47893164ab
commit
c1efda70b5
@@ -390,7 +390,7 @@ class ModelMixin(torch.nn.Module):
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import math
|
||||
from inspect import isfunction
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class AttentionBlockNew(nn.Module):
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
|
||||
to the N-d case.
|
||||
@@ -82,55 +81,6 @@ class AttentionBlockNew(nn.Module):
|
||||
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
||||
return hidden_states
|
||||
|
||||
def set_weight(self, attn_layer):
|
||||
self.group_norm.weight.data = attn_layer.norm.weight.data
|
||||
self.group_norm.bias.data = attn_layer.norm.bias.data
|
||||
|
||||
if hasattr(attn_layer, "q"):
|
||||
self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
|
||||
self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
|
||||
self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]
|
||||
|
||||
self.query.bias.data = attn_layer.q.bias.data
|
||||
self.key.bias.data = attn_layer.k.bias.data
|
||||
self.value.bias.data = attn_layer.v.bias.data
|
||||
|
||||
self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
|
||||
self.proj_attn.bias.data = attn_layer.proj_out.bias.data
|
||||
elif hasattr(attn_layer, "NIN_0"):
|
||||
self.query.weight.data = attn_layer.NIN_0.W.data.T
|
||||
self.key.weight.data = attn_layer.NIN_1.W.data.T
|
||||
self.value.weight.data = attn_layer.NIN_2.W.data.T
|
||||
|
||||
self.query.bias.data = attn_layer.NIN_0.b.data
|
||||
self.key.bias.data = attn_layer.NIN_1.b.data
|
||||
self.value.bias.data = attn_layer.NIN_2.b.data
|
||||
|
||||
self.proj_attn.weight.data = attn_layer.NIN_3.W.data.T
|
||||
self.proj_attn.bias.data = attn_layer.NIN_3.b.data
|
||||
|
||||
self.group_norm.weight.data = attn_layer.GroupNorm_0.weight.data
|
||||
self.group_norm.bias.data = attn_layer.GroupNorm_0.bias.data
|
||||
else:
|
||||
qkv_weight = attn_layer.qkv.weight.data.reshape(
|
||||
self.num_heads, 3 * self.channels // self.num_heads, self.channels
|
||||
)
|
||||
qkv_bias = attn_layer.qkv.bias.data.reshape(self.num_heads, 3 * self.channels // self.num_heads)
|
||||
|
||||
q_w, k_w, v_w = qkv_weight.split(self.channels // self.num_heads, dim=1)
|
||||
q_b, k_b, v_b = qkv_bias.split(self.channels // self.num_heads, dim=1)
|
||||
|
||||
self.query.weight.data = q_w.reshape(-1, self.channels)
|
||||
self.key.weight.data = k_w.reshape(-1, self.channels)
|
||||
self.value.weight.data = v_w.reshape(-1, self.channels)
|
||||
|
||||
self.query.bias.data = q_b.reshape(-1)
|
||||
self.key.bias.data = k_b.reshape(-1)
|
||||
self.value.bias.data = v_b.reshape(-1)
|
||||
|
||||
self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
|
||||
self.proj_attn.bias.data = attn_layer.proj.bias.data
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
@@ -170,12 +120,6 @@ class SpatialTransformer(nn.Module):
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
def set_weight(self, layer):
|
||||
self.norm = layer.norm
|
||||
self.proj_in = layer.proj_in
|
||||
self.transformer_blocks = layer.transformer_blocks
|
||||
self.proj_out = layer.proj_out
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
|
||||
@@ -203,7 +147,7 @@ class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
context_dim = context_dim if context_dim is not None else query_dim
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
@@ -234,7 +178,7 @@ class CrossAttention(nn.Module):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
context = context if context is not None else x
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
@@ -244,7 +188,7 @@ class CrossAttention(nn.Module):
|
||||
|
||||
sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
if mask is not None:
|
||||
mask = mask.reshape(batch_size, -1)
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = mask[:, None, :].repeat(h, 1, 1)
|
||||
@@ -262,8 +206,8 @@ class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
project_in = GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||
|
||||
@@ -280,155 +224,3 @@ class GEGLU(nn.Module):
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
|
||||
class NIN(nn.Module):
|
||||
def __init__(self, in_dim, num_units, init_scale=0.1):
|
||||
super().__init__()
|
||||
self.W = nn.Parameter(torch.zeros(in_dim, num_units), requires_grad=True)
|
||||
self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
# the main attention block that is used for all models
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
|
||||
Originally ported from here, but adapted to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=None,
|
||||
num_groups=32,
|
||||
encoder_channels=None,
|
||||
overwrite_qkv=False,
|
||||
overwrite_linear=False,
|
||||
rescale_output_factor=1.0,
|
||||
eps=1e-5,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels is None:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
|
||||
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
|
||||
self.qkv = nn.Conv1d(channels, channels * 3, 1)
|
||||
self.n_heads = self.num_heads
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
|
||||
if encoder_channels is not None:
|
||||
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
|
||||
|
||||
self.proj = nn.Conv1d(channels, channels, 1)
|
||||
|
||||
self.overwrite_qkv = overwrite_qkv
|
||||
self.overwrite_linear = overwrite_linear
|
||||
|
||||
if overwrite_qkv:
|
||||
in_channels = channels
|
||||
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
|
||||
self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.overwrite_linear:
|
||||
num_groups = min(channels // 4, 32)
|
||||
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=1e-6)
|
||||
self.NIN_0 = NIN(channels, channels)
|
||||
self.NIN_1 = NIN(channels, channels)
|
||||
self.NIN_2 = NIN(channels, channels)
|
||||
self.NIN_3 = NIN(channels, channels)
|
||||
|
||||
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6)
|
||||
else:
|
||||
self.proj_out = nn.Conv1d(channels, channels, 1)
|
||||
self.set_weights(self)
|
||||
|
||||
self.is_overwritten = False
|
||||
|
||||
def set_weights(self, module):
|
||||
if self.overwrite_qkv:
|
||||
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
|
||||
:, :, :, 0
|
||||
]
|
||||
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
|
||||
|
||||
self.qkv.weight.data = qkv_weight
|
||||
self.qkv.bias.data = qkv_bias
|
||||
|
||||
proj_out = nn.Conv1d(self.channels, self.channels, 1)
|
||||
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
|
||||
proj_out.bias.data = module.proj_out.bias.data
|
||||
|
||||
self.proj = proj_out
|
||||
elif self.overwrite_linear:
|
||||
self.qkv.weight.data = torch.concat(
|
||||
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
|
||||
)[:, :, None]
|
||||
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
|
||||
|
||||
self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
|
||||
self.proj.bias.data = self.NIN_3.b.data
|
||||
|
||||
self.norm.weight.data = self.GroupNorm_0.weight.data
|
||||
self.norm.bias.data = self.GroupNorm_0.bias.data
|
||||
else:
|
||||
self.proj.weight.data = self.proj_out.weight.data
|
||||
self.proj.bias.data = self.proj_out.bias.data
|
||||
|
||||
def forward(self, x, encoder_out=None):
|
||||
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
|
||||
self.set_weights(self)
|
||||
self.is_overwritten = True
|
||||
|
||||
b, c, *spatial = x.shape
|
||||
hid_states = self.norm(x).view(b, c, -1)
|
||||
|
||||
qkv = self.qkv(hid_states)
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
|
||||
if encoder_out is not None:
|
||||
encoder_kv = self.encoder_kv(encoder_out)
|
||||
assert encoder_kv.shape[1] == self.n_heads * ch * 2
|
||||
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
|
||||
k = torch.cat([ek, k], dim=-1)
|
||||
v = torch.cat([ev, v], dim=-1)
|
||||
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
|
||||
a = torch.einsum("bts,bcs->bct", weight, v)
|
||||
h = a.reshape(bs, -1, length)
|
||||
|
||||
h = self.proj(h)
|
||||
h = h.reshape(b, c, *spatial)
|
||||
|
||||
result = x + h
|
||||
|
||||
result = result / self.rescale_output_factor
|
||||
|
||||
return result
|
||||
|
||||
@@ -248,7 +248,7 @@ class FirDownsample2D(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -17,8 +17,8 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import AttentionBlockNew, SpatialTransformer
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock, Upsample2D
|
||||
from .attention import AttentionBlock, SpatialTransformer
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
|
||||
|
||||
def get_down_block(
|
||||
@@ -219,7 +219,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -236,7 +236,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
AttentionBlockNew(
|
||||
AttentionBlock(
|
||||
in_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
@@ -245,7 +245,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -299,7 +299,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -325,7 +325,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -379,7 +379,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -393,7 +393,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
AttentionBlockNew(
|
||||
AttentionBlock(
|
||||
out_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
@@ -461,7 +461,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -537,7 +537,7 @@ class DownBlock2D(nn.Module):
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -602,7 +602,7 @@ class DownEncoderBlock2D(nn.Module):
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=None,
|
||||
@@ -664,7 +664,7 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=None,
|
||||
@@ -678,7 +678,7 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
AttentionBlockNew(
|
||||
AttentionBlock(
|
||||
out_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
@@ -740,7 +740,7 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
self.resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -755,7 +755,7 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
self.attentions.append(
|
||||
AttentionBlockNew(
|
||||
AttentionBlock(
|
||||
out_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
@@ -764,7 +764,7 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
)
|
||||
|
||||
if add_downsample:
|
||||
self.resnet_down = ResnetBlock(
|
||||
self.resnet_down = ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -828,7 +828,7 @@ class SkipDownBlock2D(nn.Module):
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
self.resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -844,7 +844,7 @@ class SkipDownBlock2D(nn.Module):
|
||||
)
|
||||
|
||||
if add_downsample:
|
||||
self.resnet_down = ResnetBlock(
|
||||
self.resnet_down = ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -915,7 +915,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -929,7 +929,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
AttentionBlockNew(
|
||||
AttentionBlock(
|
||||
out_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
@@ -995,7 +995,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -1068,7 +1068,7 @@ class UpBlock2D(nn.Module):
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -1128,7 +1128,7 @@ class UpDecoderBlock2D(nn.Module):
|
||||
input_channels = in_channels if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=input_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=None,
|
||||
@@ -1184,7 +1184,7 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
input_channels = in_channels if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=input_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=None,
|
||||
@@ -1198,7 +1198,7 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
AttentionBlockNew(
|
||||
AttentionBlock(
|
||||
out_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
@@ -1257,7 +1257,7 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
self.resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -1273,7 +1273,7 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
)
|
||||
|
||||
self.attentions.append(
|
||||
AttentionBlockNew(
|
||||
AttentionBlock(
|
||||
out_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
@@ -1283,7 +1283,7 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
|
||||
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
|
||||
if add_upsample:
|
||||
self.resnet_up = ResnetBlock(
|
||||
self.resnet_up = ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -1363,7 +1363,7 @@ class SkipUpBlock2D(nn.Module):
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
self.resnets.append(
|
||||
ResnetBlock(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
@@ -1380,7 +1380,7 @@ class SkipUpBlock2D(nn.Module):
|
||||
|
||||
self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
|
||||
if add_upsample:
|
||||
self.resnet_up = ResnetBlock(
|
||||
self.resnet_up = ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
|
||||
Reference in New Issue
Block a user