1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Add unet act fn to other model components (#3136)

Adding act fn config to the unet timestep class embedding and conv
activation.

The custom activation defaults to silu which is the default
activation function for both the conv act and the timestep class
embeddings so default behavior is not changed.

The only unet which use the custom activation is the stable diffusion
latent upscaler https://huggingface.co/stabilityai/sd-x2-latent-upscaler/blob/main/unet/config.json
(I ran a script against the hub to confirm).
The latent upscaler does not use the conv activation nor the timestep
class embeddings so we don't change its behavior.
This commit is contained in:
Will Berman
2023-04-18 14:13:16 -07:00
committed by Daniel Gu
parent 00a5e55b50
commit ff5b99b81f
2 changed files with 26 additions and 4 deletions

View File

@@ -248,7 +248,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
@@ -437,7 +437,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = nn.SiLU()
if act_fn == "swish":
self.conv_act = lambda x: F.silu(x)
elif act_fn == "mish":
self.conv_act = nn.Mish()
elif act_fn == "silu":
self.conv_act = nn.SiLU()
elif act_fn == "gelu":
self.conv_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
else:
self.conv_norm_out = None
self.conv_act = None

View File

@@ -345,7 +345,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
@@ -534,7 +534,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = nn.SiLU()
if act_fn == "swish":
self.conv_act = lambda x: F.silu(x)
elif act_fn == "mish":
self.conv_act = nn.Mish()
elif act_fn == "silu":
self.conv_act = nn.SiLU()
elif act_fn == "gelu":
self.conv_act = nn.GELU()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
else:
self.conv_norm_out = None
self.conv_act = None