add layers args

Expose `up_down_block_layers` and `mid_block_layers` as `UNet1DModel` config arguments and thread them through the 1D down/up/mid block helpers and block classes.
@@ -43,6 +43,8 @@ class UNet1DModel(ModelMixin, ConfigMixin):
             obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types.
         block_out_channels (`Tuple[int]`, *optional*, defaults to :
             obj:`(32, 32, 64)`): Tuple of block output channels.
+        up_down_block_layers (`int`, *optional*, defaults to 2): Number of resnet, attention, or other layers in each up and down block.
+        mid_block_layers (`int`, *optional*, defaults to 5): Number of resnet, attention, or other layers in the mid block.
     """

     @register_to_config
@@ -61,6 +63,8 @@ class UNet1DModel(ModelMixin, ConfigMixin):
         mid_block_type: str = "UNetMidBlock1D",
         up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
         block_out_channels: Tuple[int] = (32, 32, 64),
+        up_down_block_layers: int = 2,
+        mid_block_layers: int = 5,
     ):
         super().__init__()

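Since both new arguments are plain `@register_to_config` fields, they are set at construction time and serialized with the model config. A minimal usage sketch for this branch (all other arguments left at their defaults; illustrative, not tested):

from diffusers import UNet1DModel

# Deepen every down/up block by one resnet layer relative to the default of 2;
# the mid block keeps its default depth of 5.
model = UNet1DModel(
    block_out_channels=(32, 32, 64),
    up_down_block_layers=3,
    mid_block_layers=5,
)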
@@ -98,6 +102,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
                 down_block_type,
                 in_channels=input_channel,
                 out_channels=output_channel,
+                num_layers=up_down_block_layers,
             )
             self.down_blocks.append(down_block)

@@ -107,6 +112,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
             mid_channels=block_out_channels[-1],
             in_channels=block_out_channels[-1],
             out_channels=None,
+            num_layers=mid_block_layers,
         )

         # up
@@ -120,6 +126,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
                 up_block_type,
                 in_channels=prev_output_channel,
                 out_channels=output_channel,
+                num_layers=up_down_block_layers,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -178,29 +178,29 @@ class ResConvBlock(nn.Module):
         return output


-def get_down_block(down_block_type, out_channels, in_channels):
+def get_down_block(down_block_type, num_layers, out_channels, in_channels):
     if down_block_type == "DownBlock1D":
-        return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
+        return DownBlock1D(out_channels=out_channels, in_channels=in_channels, num_layers=num_layers)
     elif down_block_type == "AttnDownBlock1D":
-        return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
+        return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels, num_layers=num_layers)
     elif down_block_type == "DownBlock1DNoSkip":
-        return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
+        return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels, num_layers=num_layers)
     raise ValueError(f"{down_block_type} does not exist.")


-def get_up_block(up_block_type, in_channels, out_channels):
+def get_up_block(up_block_type, num_layers, in_channels, out_channels):
     if up_block_type == "UpBlock1D":
-        return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
+        return UpBlock1D(in_channels=in_channels, out_channels=out_channels, num_layers=num_layers)
     elif up_block_type == "AttnUpBlock1D":
-        return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
+        return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels, num_layers=num_layers)
     elif up_block_type == "UpBlock1DNoSkip":
-        return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
+        return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels, num_layers=num_layers)
     raise ValueError(f"{up_block_type} does not exist.")


-def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels):
+def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels):
     if mid_block_type == "UNetMidBlock1D":
-        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
+        return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels, num_layers=num_layers)
     raise ValueError(f"{mid_block_type} does not exist.")
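Each factory simply forwards `num_layers` to the selected block class. A hypothetical smoke test of the dispatch, assuming the definitions in this file are in scope (untested):

# The factory threads num_layers through to the block constructor.
block = get_down_block("DownBlock1D", num_layers=3, out_channels=32, in_channels=32)
assert len(block.resnets) == 1 + 3  # one fixed input resnet plus num_layers appended below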
@@ -283,17 +283,24 @@ class AttnDownBlock1D(nn.Module):


 class DownBlock1D(nn.Module):
-    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 2):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("DownBlock1D requires num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels

         self.down = KernelDownsample1D("cubic")
         resnets = [
             ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
         ]
+
+        for i in range(num_layers):
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))

         self.resnets = nn.ModuleList(resnets)

     def forward(self, hidden_states, temb=None):
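With the default `num_layers=2` this loop rebuilds exactly the three-resnet stack it replaces: the fixed input resnet, one mid-to-mid layer, and a final layer projecting to `out_channels`. For example, `num_layers=3` with hypothetical channel counts yields:

# DownBlock1D(out_channels=64, in_channels=16, mid_channels=32, num_layers=3)
#   ResConvBlock(16, 32, 32)  # fixed input layer
#   ResConvBlock(32, 32, 32)  # i = 0
#   ResConvBlock(32, 32, 32)  # i = 1
#   ResConvBlock(32, 32, 64)  # i = 2: the last layer projects to out_channels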
@@ -306,16 +313,23 @@ class DownBlock1D(nn.Module):


 class DownBlock1DNoSkip(nn.Module):
-    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None):
+    def __init__(self, out_channels: int, in_channels: int, mid_channels: int = None, num_layers: int = 2):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("DownBlock1DNoSkip requires num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels

         resnets = [
             ResConvBlock(in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
         ]
+
+        for i in range(num_layers):
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))

         self.resnets = nn.ModuleList(resnets)

     def forward(self, hidden_states, temb=None):
@@ -327,21 +341,28 @@ class DownBlock1DNoSkip(nn.Module):


 class AttnUpBlock1D(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("AttnUpBlock1D requires num_layers >= 1")
+
         mid_channels = out_channels if mid_channels is None else mid_channels

         resnets = [
             ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
         ]
         attentions = [
             SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(mid_channels, mid_channels // 32),
-            SelfAttention1d(out_channels, out_channels // 32),
         ]
+
+        for i in range(num_layers):
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+                attentions.append(SelfAttention1d(mid_channels, mid_channels // 32))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))
+                attentions.append(SelfAttention1d(out_channels, out_channels // 32))

         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
         self.up = KernelUpsample1D(kernel="cubic")
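Here the resnet and attention lists grow in lockstep, so every `ResConvBlock` keeps a `SelfAttention1d` sized to its output channels (`channels // 32` heads). A hypothetical count check, assuming the class is in scope (untested):

block = AttnUpBlock1D(in_channels=32, out_channels=32, num_layers=4)
assert len(block.resnets) == len(block.attentions) == 1 + 4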
@@ -360,16 +381,23 @@ class AttnUpBlock1D(nn.Module):


 class UpBlock1D(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("UpBlock1D requires num_layers >= 1")
+
         mid_channels = in_channels if mid_channels is None else mid_channels

         resnets = [
             ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels),
         ]
+
+        for i in range(num_layers):
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels))

         self.resnets = nn.ModuleList(resnets)
         self.up = KernelUpsample1D(kernel="cubic")
@@ -386,16 +414,23 @@ class UpBlock1D(nn.Module):


 class UpBlock1DNoSkip(nn.Module):
-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: int = None, num_layers: int = 2):
         super().__init__()
+        if num_layers < 1:
+            raise ValueError("UpBlock1DNoSkip requires num_layers >= 1")
+
         mid_channels = in_channels if mid_channels is None else mid_channels

         resnets = [
             ResConvBlock(2 * in_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, mid_channels),
-            ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
         ]
+
+        for i in range(num_layers):
+            if i < (num_layers - 1):
+                resnets.append(ResConvBlock(mid_channels, mid_channels, mid_channels))
+            else:
+                resnets.append(ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True))

         self.resnets = nn.ModuleList(resnets)

     def forward(self, hidden_states, res_hidden_states_tuple):
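With the default `up_down_block_layers=2`, each down and up block keeps its previous three-resnet depth, so existing configs build the same architecture; deeper settings show up directly in the module tree. A final hypothetical sanity check for this branch (untested; assumes every down block type follows the 1 + num_layers pattern implied by the factory change):

from diffusers import UNet1DModel

model = UNet1DModel(up_down_block_layers=3)
for block in model.down_blocks:
    assert len(block.resnets) == 1 + 3  # fixed input resnet plus num_layers appended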