From 5df4c8b81f04ff2fa578ae3c54a9d20c2458b76b Mon Sep 17 00:00:00 2001
From: Nathan Lambert
Date: Wed, 26 Oct 2022 15:56:04 -0700
Subject: [PATCH] add layers args

---
 src/diffusers/models/unet_1d.py        |  7 ++
 src/diffusers/models/unet_1d_blocks.py | 89 ++++++++++++++++++--------
 2 files changed, 69 insertions(+), 27 deletions(-)

diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py
index cc0685deb9..37c3189312 100644
--- a/src/diffusers/models/unet_1d.py
+++ b/src/diffusers/models/unet_1d.py
@@ -43,6 +43,8 @@ class UNet1DModel(ModelMixin, ConfigMixin):
             obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
         block_out_channels (`Tuple[int]`, *optional*, defaults to :
             obj:`(32, 32, 64)`): Tuple of block output channels.
+        up_down_block_layers (`int`, *optional*, defaults to 2): Number of resnet/attention layers appended to each up and down block.
+        mid_block_layers (`int`, *optional*, defaults to 5): Number of resnet/attention layers in the mid block.
     """
 
     @register_to_config
@@ -61,6 +63,8 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         mid_block_type: str = "UNetMidBlock1D",
         up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
         block_out_channels: Tuple[int] = (32, 32, 64),
+        up_down_block_layers: int = 2,
+        mid_block_layers: int = 5,
     ):
         super().__init__()
 
@@ -98,6 +102,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
                 down_block_type,
                 in_channels=input_channel,
                 out_channels=output_channel,
+                num_layers=up_down_block_layers,
             )
             self.down_blocks.append(down_block)
 
@@ -107,6 +112,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
             mid_channels=block_out_channels[-1],
             in_channels=block_out_channels[-1],
             out_channels=None,
+            num_layers=mid_block_layers,
         )
 
         # up
@@ -120,6 +126,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
                 up_block_type,
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
+                num_layers=up_down_block_layers,
            )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py
index 0b64155382..71a9b568c1 100644
--- a/src/diffusers/models/unet_1d_blocks.py
+++ b/src/diffusers/models/unet_1d_blocks.py
@@ -178,29 +178,29 @@ class ResConvBlock(nn.Module):
         return output
 
 
-def get_down_block(down_block_type, out_channels, in_channels):
+def get_down_block(down_block_type, num_layers, out_channels, in_channels):
     if down_block_type == "DownBlock1D":
-        return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
+        return DownBlock1D(out_channels=out_channels, in_channels=in_channels, num_layers=num_layers)
     elif down_block_type == "AttnDownBlock1D":
-        return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
+        return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels, num_layers=num_layers)
     elif down_block_type == "DownBlock1DNoSkip":
-        return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
+        return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels, num_layers=num_layers)
     raise ValueError(f"{down_block_type} does not exist.")
 
 
-def get_up_block(up_block_type, in_channels, out_channels):
+def get_up_block(up_block_type, num_layers, in_channels, out_channels):
     if up_block_type == "UpBlock1D":
-        return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
+        return UpBlock1D(in_channels=in_channels, out_channels=out_channels, num_layers=num_layers)
     elif up_block_type == "AttnUpBlock1D":
-        return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
+        return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels, num_layers=num_layers)
     elif up_block_type == "UpBlock1DNoSkip":
-        return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
+        return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels, num_layers=num_layers)
     raise ValueError(f"{up_block_type} does not exist.")
 
 
-def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels):
+def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels):
     if mid_block_type == "UNetMidBlock1D":
-        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
+        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels, num_layers=num_layers)
     raise ValueError(f"{mid_block_type} does not exist.")
 
 
@@ -283,17 +283,24 @@ class AttnDownBlock1D(nn.Module):
 
 
 class DownBlock1D(nn.Module):
-    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 2):
         super().__init__()
 
+        if num_layers < 1:
+            raise ValueError(f"DownBlock1D requires num_layers >= 1, got {num_layers}.")
+
         mid_channels = out_channels if mid_channels is None else mid_channels
 
         self.down = KernelDownsample1D("cubic")
         resnets = [
             ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
         ]
+        for i in range(num_layers):
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
+
         self.resnets = nn.ModuleList(resnets)
 
     def forward(self, hidden_states, temb=None):
@@ -306,16 +313,23 @@ class DownBlock1D(nn.Module):
 
 
 class DownBlock1DNoSkip(nn.Module):
-    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 2):
         super().__init__()
 
+        if num_layers < 1:
+            raise ValueError(f"DownBlock1DNoSkip requires num_layers >= 1, got {num_layers}.")
+
         mid_channels = out_channels if mid_channels is None else mid_channels
 
         resnets = [
             ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
         ]
+        for i in range(num_layers):
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
+
         self.resnets = nn.ModuleList(resnets)
 
     def forward(self, hidden_states, temb=None):
@@ -327,21 +341,28 @@ class DownBlock1DNoSkip(nn.Module):
 
 
 class AttnUpBlock1D(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
         super().__init__()
 
+        if num_layers < 1:
+            raise ValueError(f"AttnUpBlock1D requires num_layers >= 1, got {num_layers}.")
+
         mid_channels = out_channels if mid_channels is None else mid_channels
 
         resnets = [
             ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
         ]
         attentions = [
             SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(out_channels, out_channels // 32),
         ]
+        for i in range(num_layers):
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+                attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
+                attentions.append(SelfAttention1d(out_channels, out_channels // 32))
+
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
         self.up = KernelUpsample1D(kernel="cubic")
@@ -360,16 +381,23 @@ class AttnUpBlock1D(nn.Module):
 
 
 class UpBlock1D(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
         super().__init__()
 
+        if num_layers < 1:
+            raise ValueError(f"UpBlock1D requires num_layers >= 1, got {num_layers}.")
+
         mid_channels = in_channels if mid_channels is None else mid_channels
 
         resnets = [
             ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
         ]
+        for i in range(num_layers):
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
+
         self.resnets = nn.ModuleList(resnets)
         self.up = KernelUpsample1D(kernel="cubic")
 
@@ -386,16 +414,23 @@ class UpBlock1D(nn.Module):
 
 
 class UpBlock1DNoSkip(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
         super().__init__()
 
+        if num_layers < 1:
+            raise ValueError(f"UpBlock1DNoSkip requires num_layers >= 1, got {num_layers}.")
+
         mid_channels = in_channels if mid_channels is None else mid_channels
 
         resnets = [
             ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
         ]
+        for i in range(num_layers):
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True))
+
         self.resnets = nn.ModuleList(resnets)
 
     def forward(self, hidden_states, res_hidden_states_tuple):
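
For reference, a minimal usage sketch of what the new argument does at the block level, assuming this patch is applied to a local diffusers checkout. The class names and keyword arguments are taken from the hunks above; the channel sizes are arbitrary illustration values, not part of the patch.

    # Sketch only: num_layers counts the ResConvBlocks appended after each block's
    # fixed first layer, so a block built with num_layers=N holds N + 1 ResConvBlocks;
    # the default of 2 reproduces the previous hard-coded three-layer blocks.
    from diffusers.models.unet_1d_blocks import DownBlock1D, UpBlock1D

    down = DownBlock1D(in_channels=32, out_channels=64, num_layers=4)
    assert len(down.resnets) == 5  # 1 fixed input layer + 4 appended layers

    up = UpBlock1D(in_channels=64, out_channels=32)  # default num_layers=2
    assert len(up.resnets) == 3  # same block size as before this patch

    # num_layers < 1 is rejected by the validation added in each block's __init__.

At the model level, UNet1DModel forwards the same setting to its blocks through the new up_down_block_layers and mid_block_layers config arguments.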