mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Add unet act fn to other model components (#3136)
Adding act fn config to the unet timestep class embedding and conv activation. The custom activation defaults to silu which is the default activation function for both the conv act and the timestep class embeddings so default behavior is not changed. The only unet which use the custom activation is the stable diffusion latent upscaler https://huggingface.co/stabilityai/sd-x2-latent-upscaler/blob/main/unet/config.json (I ran a script against the hub to confirm). The latent upscaler does not use the conv activation nor the timestep class embeddings so we don't change its behavior.
This commit is contained in:
@@ -248,7 +248,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
@@ -437,7 +437,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
if act_fn == "swish":
|
||||
self.conv_act = lambda x: F.silu(x)
|
||||
elif act_fn == "mish":
|
||||
self.conv_act = nn.Mish()
|
||||
elif act_fn == "silu":
|
||||
self.conv_act = nn.SiLU()
|
||||
elif act_fn == "gelu":
|
||||
self.conv_act = nn.GELU()
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function: {act_fn}")
|
||||
|
||||
else:
|
||||
self.conv_norm_out = None
|
||||
self.conv_act = None
|
||||
|
||||
@@ -345,7 +345,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
@@ -534,7 +534,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
if act_fn == "swish":
|
||||
self.conv_act = lambda x: F.silu(x)
|
||||
elif act_fn == "mish":
|
||||
self.conv_act = nn.Mish()
|
||||
elif act_fn == "silu":
|
||||
self.conv_act = nn.SiLU()
|
||||
elif act_fn == "gelu":
|
||||
self.conv_act = nn.GELU()
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function: {act_fn}")
|
||||
|
||||
else:
|
||||
self.conv_norm_out = None
|
||||
self.conv_act = None
|
||||
|
||||
Reference in New Issue
Block a user