diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 45ef3df36d..18ee107a34 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -434,7 +434,6 @@ class UNetMidBlock1D(nn.Module): out_channels = in_channels if out_channels is None else out_channels - # there is always at least one resnet self.down = KernelDownsample1D("cubic") resnets = [ ResConvBlock(in_channels, mid_channels, mid_channels), @@ -469,7 +468,7 @@ class UNetMidBlock1D(nn.Module): class AttnDownBlock1D(nn.Module): - def __init__(self, out_channels: int, in_channels: int, num_layers: int = 2, mid_channels: int = None): + def __init__(self, out_channels: int, in_channels: int, num_layers: int = 3, mid_channels: int = None): super().__init__() if num_layers < 1: @@ -478,16 +477,13 @@ class AttnDownBlock1D(nn.Module): mid_channels = out_channels if mid_channels is None else mid_channels self.down = KernelDownsample1D("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - ] + resnets = [] + attentions = [] for i in range(num_layers): + in_channels = in_channels if i == 0 else mid_channels if i < (num_layers - 1): - resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels)) + resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels)) attentions.append(SelfAttention1d(mid_channels, mid_channels // 32)) else: resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels)) @@ -507,7 +503,7 @@ class AttnDownBlock1D(nn.Module): class DownBlock1D(nn.Module): - def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 2): + def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3): super().__init__() if num_layers < 1: raise ValueError("DownBlock1D requires added num_layers >= 1") @@ -515,13 +511,12 @@ class DownBlock1D(nn.Module): mid_channels = out_channels if mid_channels is None else mid_channels self.down = KernelDownsample1D("cubic") - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ] + resnets = [] for i in range(num_layers): + in_channels = in_channels if i == 0 else mid_channels if i < (num_layers - 1): - resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels)) + resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels)) else: resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels)) @@ -537,20 +532,19 @@ class DownBlock1D(nn.Module): class DownBlock1DNoSkip(nn.Module): - def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 2): + def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 3): super().__init__() if num_layers < 1: raise ValueError("DownBlock1DNoSkip requires added num_layers >= 1") mid_channels = out_channels if mid_channels is None else mid_channels - resnets = [ - ResConvBlock(in_channels, mid_channels, mid_channels), - ] + resnets = [] for i in range(num_layers): + in_channels = in_channels if i == 0 else mid_channels if i < (num_layers - 1): - resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels)) + resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels)) else: resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels)) @@ -565,23 +559,20 @@ class DownBlock1DNoSkip(nn.Module): class AttnUpBlock1D(nn.Module): - def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2): + def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3): super().__init__() if num_layers < 1: raise ValueError("AttnUpBlock1D requires added num_layers >= 1") mid_channels = out_channels if mid_channels is None else mid_channels - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ] - attentions = [ - SelfAttention1d(mid_channels, mid_channels // 32), - ] + resnets = [] + attentions = [] for i in range(num_layers): + in_channels = 2 * in_channels if i == 0 else mid_channels if i < (num_layers - 1): - resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels)) + resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels)) attentions.append(SelfAttention1d(mid_channels, mid_channels // 32)) else: resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels)) @@ -605,20 +596,19 @@ class AttnUpBlock1D(nn.Module): class UpBlock1D(nn.Module): - def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2): + def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3): super().__init__() if num_layers < 1: raise ValueError("UpBlock1D requires added num_layers >= 1") mid_channels = in_channels if mid_channels is None else mid_channels - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ] + resnets = [] for i in range(num_layers): + in_channels = 2 * in_channels if i == 0 else mid_channels if i < (num_layers - 1): - resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels)) + resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels)) else: resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels)) @@ -638,20 +628,19 @@ class UpBlock1D(nn.Module): class UpBlock1DNoSkip(nn.Module): - def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2): + def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 3): super().__init__() if num_layers < 1: raise ValueError("UpBlock1D requires added num_layers >= 1") mid_channels = in_channels if mid_channels is None else mid_channels - resnets = [ - ResConvBlock(2 * in_channels, mid_channels, mid_channels), - ] + resnets = [] for i in range(num_layers): + in_channels = 2 * in_channels if i == 0 else mid_channels if i < (num_layers - 1): - resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels)) + resnets.append(ResConvBlock(in_channels, mid_channels, mid_channels)) else: resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True))