Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
resnet skip time activation and output scale factor
This commit is contained in:
Committed by Will Berman
Parent: 26b4319ac5
Commit: 707341aebe
@@ -459,6 +459,7 @@ class ResnetBlock2D(nn.Module):
|
||||
pre_norm=True,
|
||||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
skip_time_act=False,
|
||||
time_embedding_norm="default", # default, scale_shift, ada_group
|
||||
kernel=None,
|
||||
output_scale_factor=1.0,
|
||||
@@ -479,6 +480,7 @@ class ResnetBlock2D(nn.Module):
|
||||
self.down = down
|
||||
self.output_scale_factor = output_scale_factor
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.skip_time_act = skip_time_act
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
@@ -570,7 +572,9 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if self.time_emb_proj is not None:
|
||||
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
if not self.skip_time_act:
|
||||
temb = self.nonlinearity(temb)
|
||||
temb = self.time_emb_proj(temb)[:, :, None, None]
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "default":
|
||||
hidden_states = hidden_states + temb
|
||||
|
||||
@@ -42,6 +42,8 @@ def get_down_block(
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_skip_time_act=False,
|
||||
resnet_out_scale_factor=1.0,
|
||||
):
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlock2D":
|
||||
@@ -68,6 +70,8 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
skip_time_act=resnet_skip_time_act,
|
||||
output_scale_factor=resnet_out_scale_factor,
|
||||
)
|
||||
elif down_block_type == "AttnDownBlock2D":
|
||||
return AttnDownBlock2D(
|
||||
@@ -119,6 +123,8 @@ def get_down_block(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
skip_time_act=resnet_skip_time_act,
|
||||
output_scale_factor=resnet_out_scale_factor,
|
||||
)
|
||||
elif down_block_type == "SkipDownBlock2D":
|
||||
return SkipDownBlock2D(
|
||||
@@ -214,6 +220,8 @@ def get_up_block(
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_skip_time_act=False,
|
||||
resnet_out_scale_factor=1.0,
|
||||
):
|
||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||
if up_block_type == "UpBlock2D":
|
||||
@@ -241,6 +249,8 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
skip_time_act=resnet_skip_time_act,
|
||||
output_scale_factor=resnet_out_scale_factor,
|
||||
)
|
||||
elif up_block_type == "CrossAttnUpBlock2D":
|
||||
if cross_attention_dim is None:
|
||||
@@ -279,6 +289,8 @@ def get_up_block(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
skip_time_act=resnet_skip_time_act,
|
||||
output_scale_factor=resnet_out_scale_factor,
|
||||
)
|
||||
elif up_block_type == "AttnUpBlock2D":
|
||||
return AttnUpBlock2D(
|
||||
@@ -562,6 +574,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
skip_time_act=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -585,6 +598,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
@@ -615,6 +629,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1247,6 +1262,7 @@ class ResnetDownsampleBlock2D(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
skip_time_act=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -1265,6 +1281,7 @@ class ResnetDownsampleBlock2D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1284,6 +1301,7 @@ class ResnetDownsampleBlock2D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
down=True,
|
||||
)
|
||||
]
|
||||
@@ -1337,6 +1355,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
skip_time_act=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -1362,6 +1381,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
@@ -1394,6 +1414,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
down=True,
|
||||
)
|
||||
]
|
||||
@@ -2237,6 +2258,7 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
skip_time_act=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -2257,6 +2279,7 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2276,6 +2299,7 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
up=True,
|
||||
)
|
||||
]
|
||||
@@ -2329,6 +2353,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
skip_time_act=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -2355,6 +2380,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
@@ -2387,6 +2413,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
up=True,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -146,6 +146,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_skip_time_act: bool = False,
|
||||
resnet_out_scale_factor: int = 1.0,
|
||||
time_embedding_type: str = "positional",
|
||||
timestep_post_act: Optional[str] = None,
|
||||
time_cond_proj_dim: Optional[int] = None,
|
||||
@@ -291,6 +293,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_skip_time_act=resnet_skip_time_act,
|
||||
resnet_out_scale_factor=resnet_out_scale_factor,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -321,6 +325,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
skip_time_act=resnet_skip_time_act,
|
||||
)
|
||||
elif mid_block_type is None:
|
||||
self.mid_block = None
|
||||
@@ -369,6 +374,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_skip_time_act=resnet_skip_time_act,
|
||||
resnet_out_scale_factor=resnet_out_scale_factor,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
@@ -232,6 +232,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_skip_time_act: bool = False,
|
||||
resnet_out_scale_factor: int = 1.0,
|
||||
time_embedding_type: str = "positional",
|
||||
timestep_post_act: Optional[str] = None,
|
||||
time_cond_proj_dim: Optional[int] = None,
|
||||
@@ -382,6 +384,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_skip_time_act=resnet_skip_time_act,
|
||||
resnet_out_scale_factor=resnet_out_scale_factor,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -412,6 +416,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
skip_time_act=resnet_skip_time_act,
|
||||
)
|
||||
elif mid_block_type is None:
|
||||
self.mid_block = None
|
||||
@@ -460,6 +465,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_skip_time_act=resnet_skip_time_act,
|
||||
resnet_out_scale_factor=resnet_out_scale_factor,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
@@ -1434,6 +1441,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
skip_time_act=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -1457,6 +1465,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
@@ -1487,6 +1496,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
skip_time_act=skip_time_act,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user