1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

resnet skip time activation and output scale factor

This commit is contained in:
William Berman
2023-04-08 18:12:46 -07:00
committed by Will Berman
parent 26b4319ac5
commit 707341aebe
4 changed files with 49 additions and 1 deletions

View File

@@ -459,6 +459,7 @@ class ResnetBlock2D(nn.Module):
pre_norm=True,
eps=1e-6,
non_linearity="swish",
skip_time_act=False,
time_embedding_norm="default", # default, scale_shift, ada_group
kernel=None,
output_scale_factor=1.0,
@@ -479,6 +480,7 @@ class ResnetBlock2D(nn.Module):
self.down = down
self.output_scale_factor = output_scale_factor
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
if groups_out is None:
groups_out = groups
@@ -570,7 +572,9 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb

View File

@@ -42,6 +42,8 @@ def get_down_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -68,6 +70,8 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
)
elif down_block_type == "AttnDownBlock2D":
return AttnDownBlock2D(
@@ -119,6 +123,8 @@ def get_down_block(
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -214,6 +220,8 @@ def get_up_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -241,6 +249,8 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
)
elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
@@ -279,6 +289,8 @@ def get_up_block(
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
output_scale_factor=resnet_out_scale_factor,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -562,6 +574,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
attn_num_head_channels=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
skip_time_act=False,
):
super().__init__()
@@ -585,6 +598,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
]
attentions = []
@@ -615,6 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
@@ -1247,6 +1262,7 @@ class ResnetDownsampleBlock2D(nn.Module):
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
skip_time_act=False,
):
super().__init__()
resnets = []
@@ -1265,6 +1281,7 @@ class ResnetDownsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
@@ -1284,6 +1301,7 @@ class ResnetDownsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
down=True,
)
]
@@ -1337,6 +1355,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
cross_attention_dim=1280,
output_scale_factor=1.0,
add_downsample=True,
skip_time_act=False,
):
super().__init__()
@@ -1362,6 +1381,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
attentions.append(
@@ -1394,6 +1414,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
down=True,
)
]
@@ -2237,6 +2258,7 @@ class ResnetUpsampleBlock2D(nn.Module):
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
skip_time_act=False,
):
super().__init__()
resnets = []
@@ -2257,6 +2279,7 @@ class ResnetUpsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
@@ -2276,6 +2299,7 @@ class ResnetUpsampleBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
up=True,
)
]
@@ -2329,6 +2353,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
skip_time_act=False,
):
super().__init__()
resnets = []
@@ -2355,6 +2380,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)
attentions.append(
@@ -2387,6 +2413,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
up=True,
)
]

View File

@@ -146,6 +146,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional",
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
@@ -291,6 +293,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
)
self.down_blocks.append(down_block)
@@ -321,6 +325,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
)
elif mid_block_type is None:
self.mid_block = None
@@ -369,6 +374,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel

View File

@@ -232,6 +232,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: int = 1.0,
time_embedding_type: str = "positional",
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
@@ -382,6 +384,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
)
self.down_blocks.append(down_block)
@@ -412,6 +416,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
)
elif mid_block_type is None:
self.mid_block = None
@@ -460,6 +465,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -1434,6 +1441,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
attn_num_head_channels=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
skip_time_act=False,
):
super().__init__()
@@ -1457,6 +1465,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
]
attentions = []
@@ -1487,6 +1496,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
skip_time_act=skip_time_act,
)
)