From ff5b99b81f67260b60023063a2c763b005b4c217 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Tue, 18 Apr 2023 14:13:16 -0700 Subject: [PATCH] Add unet act fn to other model components (#3136) Adding act fn config to the unet timestep class embedding and conv activation. The custom activation defaults to silu, which is the default activation function for both the conv act and the timestep class embeddings, so default behavior is not changed. The only unet which uses the custom activation is the stable diffusion latent upscaler https://huggingface.co/stabilityai/sd-x2-latent-upscaler/blob/main/unet/config.json (I ran a script against the hub to confirm). The latent upscaler uses neither the conv activation nor the timestep class embeddings, so we don't change its behavior. --- src/diffusers/models/unet_2d_condition.py | 15 +++++++++++++-- .../versatile_diffusion/modeling_text_unet.py | 15 +++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index b281435693..29de8734d4 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -248,7 +248,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": @@ -437,7 +437,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - 
self.conv_act = nn.SiLU() + + if act_fn == "swish": + self.conv_act = lambda x: F.silu(x) + elif act_fn == "mish": + self.conv_act = nn.Mish() + elif act_fn == "silu": + self.conv_act = nn.SiLU() + elif act_fn == "gelu": + self.conv_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + else: self.conv_norm_out = None self.conv_act = None diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 4377be1181..b20f18c485 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -345,7 +345,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": @@ -534,7 +534,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - self.conv_act = nn.SiLU() + + if act_fn == "swish": + self.conv_act = lambda x: F.silu(x) + elif act_fn == "mish": + self.conv_act = nn.Mish() + elif act_fn == "silu": + self.conv_act = nn.SiLU() + elif act_fn == "gelu": + self.conv_act = nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + else: self.conv_norm_out = None self.conv_act = None