1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

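fix weird layer counting: build every ResConvBlock inside the loop so that num_layers equals the actual number of resnets in the 1D down/up blocks (defaults raised from 2 to 3 to keep the same depth)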

This commit is contained in:
Nathan Lambert
2022-11-29 15:40:54 -08:00
parent 198fd951ec
commit dc2c3992d1

View File

@@ -434,7 +434,6 @@ class UNetMidBlock1D(nn.Module):
out_channels = in_channels if out_channels is None else out_channels
# there is always at least one resnet
self.down = KernelDownsample1D("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
@@ -469,7 +468,7 @@ class UNetMidBlock1D(nn.Module):
class AttnDownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, num_layers: int = 2, mid_channels: int = None):
def __init__(self, out_channels: int, in_channels: int, num_layers: int = 3, mid_channels: int = None):
super().__init__()
if num_layers < 1:
@@ -478,16 +477,13 @@ class AttnDownBlock1D(nn.Module):
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = KernelDownsample1D("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
]
resnets = []
attentions = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
@@ -507,7 +503,7 @@ class AttnDownBlock1D(nn.Module):
class DownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 2):
def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3):
super().__init__()
if num_layers < 1:
raise ValueError("DownBlock1D requires added num_layers >= 1")
@@ -515,13 +511,12 @@ class DownBlock1D(nn.Module):
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = KernelDownsample1D("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
]
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
@@ -537,20 +532,19 @@ class DownBlock1D(nn.Module):
class DownBlock1DNoSkip(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 2):
def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3):
super().__init__()
if num_layers < 1:
raise ValueError("DownBlock1DNoSkip requires added num_layers >= 1")
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
]
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
@@ -565,23 +559,20 @@ class DownBlock1DNoSkip(nn.Module):
class AttnUpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
super().__init__()
if num_layers < 1:
raise ValueError("AttnUpBlock1D requires added num_layers >= 1")
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
]
resnets = []
attentions = []
for i in range(num_layers):
in_channels = 2 * in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
@@ -605,20 +596,19 @@ class AttnUpBlock1D(nn.Module):
class UpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
super().__init__()
if num_layers < 1:
raise ValueError("UpBlock1D requires added num_layers >= 1")
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
]
resnets = []
for i in range(num_layers):
in_channels = 2 * in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
@@ -638,20 +628,19 @@ class UpBlock1D(nn.Module):
class UpBlock1DNoSkip(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
super().__init__()
if num_layers < 1:
raise ValueError("UpBlock1D requires added num_layers >= 1")
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
]
resnets = []
for i in range(num_layers):
in_channels = 2 * in_channels if i == 0 else mid_channels
if i < (num_layers - 1):
resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
else:
resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True))