diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 4132ccbd0c..770043f053 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -462,7 +462,7 @@ class AttnDownBlock2D(nn.Module): self.downsamplers = nn.ModuleList( [ Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) @@ -546,7 +546,7 @@ class CrossAttnDownBlock2D(nn.Module): self.downsamplers = nn.ModuleList( [ Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) @@ -651,7 +651,7 @@ class DownBlock2D(nn.Module): self.downsamplers = nn.ModuleList( [ Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) @@ -729,7 +729,7 @@ class DownEncoderBlock2D(nn.Module): self.downsamplers = nn.ModuleList( [ Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) @@ -801,7 +801,7 @@ class AttnDownEncoderBlock2D(nn.Module): self.downsamplers = nn.ModuleList( [ Downsample2D( - in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) @@ -886,7 +886,7 @@ class AttnSkipDownBlock2D(nn.Module): down=True, kernel="fir", ) - self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) else: self.resnet_down = None @@ -966,7 +966,7 @@ class SkipDownBlock2D(nn.Module): down=True, kernel="fir", ) - self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)]) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) else: self.resnet_down = None