Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
fix weird layer counting
@@ -434,7 +434,6 @@ class UNetMidBlock1D(nn.Module):
         out_channels = in_channels if out_channels is None else out_channels

-        # there is always at least one resnet
         self.down = KernelDownsample1D("cubic")
         resnets = [
             ResConvBlock(in_channels, mid_channels, mid_channels),
@@ -469,7 +468,7 @@ class UNetMidBlock1D(nn.Module):


 class AttnDownBlock1D(nn.Module):
-    def __init__(self, out_channels: int, in_channels: int, num_layers: int = 2, mid_channels: int = None):
+    def __init__(self, out_channels: int, in_channels: int, num_layers: int = 3, mid_channels: int = None):
         super().__init__()

         if num_layers < 1:
@@ -478,16 +477,13 @@ class AttnDownBlock1D(nn.Module):
         mid_channels = out_channels if mid_channels is None else mid_channels

         self.down = KernelDownsample1D("cubic")
-        resnets = [
-            ResConvBlock(in_channels, mid_channels, mid_channels),
-        ]
-        attentions = [
-            SelfAttention1d(mid_channels, mid_channels // 32),
-        ]
+        resnets = []
+        attentions = []

         for i in range(num_layers):
+            in_channels = in_channels if i == 0 else mid_channels
             if i < (num_layers - 1):
-                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
                 attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
             else:
                 resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
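
The "weird layer counting" being fixed: the old constructors seeded the list with one hard-coded ResConvBlock and then appended num_layers more, so the default num_layers=2 actually produced three resnets. The rewritten loop builds exactly num_layers blocks and routes in_channels through the first iteration, so the default moves from 2 to 3 and the resulting module stays the same. A minimal counting sketch, not part of the commit, with stand-in strings in place of the real modules:

def old_count(num_layers: int = 2) -> int:
    resnets = ["seed"]                  # the old hard-coded first ResConvBlock
    resnets += ["loop"] * num_layers    # plus num_layers appended in the loop
    return len(resnets)                 # default 2 -> 3 resnets

def new_count(num_layers: int = 3) -> int:
    resnets = ["loop"] * num_layers     # the loop now builds every layer
    return len(resnets)                 # default 3 -> 3 resnets

assert old_count() == new_count() == 3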
@@ -507,7 +503,7 @@ class AttnDownBlock1D(nn.Module):


 class DownBlock1D(nn.Module):
-    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 2):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
         if num_layers < 1:
             raise ValueError("DownBlock1D requires added num_layers >= 1")
@@ -515,13 +511,12 @@ class DownBlock1D(nn.Module):
         mid_channels = out_channels if mid_channels is None else mid_channels

         self.down = KernelDownsample1D("cubic")
-        resnets = [
-            ResConvBlock(in_channels, mid_channels, mid_channels),
-        ]
+        resnets = []

         for i in range(num_layers):
+            in_channels = in_channels if i == 0 else mid_channels
             if i < (num_layers - 1):
-                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
             else:
                 resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
@@ -537,20 +532,19 @@ class DownBlock1D(nn.Module):


 class DownBlock1DNoSkip(nn.Module):
-    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 2):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
         if num_layers < 1:
             raise ValueError("DownBlock1DNoSkip requires added num_layers >= 1")

         mid_channels = out_channels if mid_channels is None else mid_channels

-        resnets = [
-            ResConvBlock(in_channels, mid_channels, mid_channels),
-        ]
+        resnets = []

         for i in range(num_layers):
+            in_channels = in_channels if i == 0 else mid_channels
             if i < (num_layers - 1):
-                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
             else:
                 resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
@@ -565,23 +559,20 @@ class DownBlock1DNoSkip(nn.Module):


 class AttnUpBlock1D(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
         if num_layers < 1:
             raise ValueError("AttnUpBlock1D requires added num_layers >= 1")

         mid_channels = out_channels if mid_channels is None else mid_channels

-        resnets = [
-            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-        ]
-        attentions = [
-            SelfAttention1d(mid_channels, mid_channels // 32),
-        ]
+        resnets = []
+        attentions = []

         for i in range(num_layers):
+            in_channels = 2 * in_channels if i == 0 else mid_channels
             if i < (num_layers - 1):
-                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
                 attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
             else:
                 resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
@@ -605,20 +596,19 @@ class AttnUpBlock1D(nn.Module):


 class UpBlock1D(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
         if num_layers < 1:
             raise ValueError("UpBlock1D requires added num_layers >= 1")

         mid_channels = in_channels if mid_channels is None else mid_channels

-        resnets = [
-            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-        ]
+        resnets = []

         for i in range(num_layers):
+            in_channels = 2 * in_channels if i == 0 else mid_channels
             if i < (num_layers - 1):
-                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
             else:
                 resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
@@ -638,20 +628,19 @@ class UpBlock1D(nn.Module):


 class UpBlock1DNoSkip(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3):
         super().__init__()
         if num_layers < 1:
             raise ValueError("UpBlock1D requires added num_layers >= 1")

         mid_channels = in_channels if mid_channels is None else mid_channels

-        resnets = [
-            ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-        ]
+        resnets = []

         for i in range(num_layers):
+            in_channels = 2 * in_channels if i == 0 else mid_channels
             if i < (num_layers - 1):
-                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+                resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels))
             else:
                 resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True))
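
The up-blocks take 2 * in_channels on the first layer, which matches a forward pass that concatenates the skip tensor along the channel dimension, and UpBlock1DNoSkip flags its final ResConvBlock with is_last=True. A sketch of the per-layer channel plan the shared loop produces, not part of the commit, using plain tuples in place of the real modules:

def up_block_channel_plan(in_channels: int, mid_channels: int, out_channels: int, num_layers: int = 3):
    """(input, output) channel pairs per ResConvBlock, mirroring the loop above."""
    plan = []
    for i in range(num_layers):
        c_in = 2 * in_channels if i == 0 else mid_channels   # skip concat doubles layer 0
        c_out = out_channels if i == num_layers - 1 else mid_channels
        plan.append((c_in, c_out))
    return plan

assert up_block_channel_plan(32, 32, 16) == [(64, 32), (32, 32), (32, 16)]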