
move activation dispatches into helper function (#3656)

* move activation dispatches into helper function

* tests
Author: Will Berman
Date: 2023-06-05 12:30:48 -07:00 (committed by GitHub)
Parent: 462956be7b
Commit: 41ae670828

8 changed files with 89 additions and 101 deletions
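
In short, every call site that previously carried its own if/elif chain over activation names now delegates to a single shared helper. A minimal sketch of the call-site change, using an illustrative module (the class below is not from the diff; the real call sites are AdaGroupNorm, TimestepEmbedding, ResnetBlock2D, the 1D blocks, and the two UNet condition models):

from torch import nn

from diffusers.models.activations import get_activation  # helper introduced by this commit


class TinyBlock(nn.Module):
    def __init__(self, dim: int, act_fn: str = "silu"):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Previously each module did: if act_fn == "swish": ... elif act_fn == "mish": ... etc.
        self.act = get_activation(act_fn)

    def forward(self, x):
        return self.act(self.proj(x))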


@@ -0,0 +1,12 @@
+from torch import nn
+
+
+def get_activation(act_fn):
+    if act_fn in ["swish", "silu"]:
+        return nn.SiLU()
+    elif act_fn == "mish":
+        return nn.Mish()
+    elif act_fn == "gelu":
+        return nn.GELU()
+    else:
+        raise ValueError(f"Unsupported activation function: {act_fn}")
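
A quick standalone check of the helper's behavior, based only on the definition above (not part of the commit):

import torch
from torch import nn

from diffusers.models.activations import get_activation

act = get_activation("swish")
print(isinstance(act, nn.SiLU))       # True: "swish" and "silu" map to the same module
print(act(torch.tensor(2.0)).item())  # ~1.76, i.e. 2 * sigmoid(2)

try:
    get_activation("relu")            # not in the supported list, so the helper raises
except ValueError as err:
    print(err)                        # Unsupported activation function: relu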


@@ -18,6 +18,7 @@ import torch.nn.functional as F
 from torch import nn
 from ..utils import maybe_allow_in_graph
+from .activations import get_activation
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
@@ -345,15 +346,11 @@ class AdaGroupNorm(nn.Module):
         super().__init__()
         self.num_groups = num_groups
         self.eps = eps
-        self.act = None
-        if act_fn == "swish":
-            self.act = lambda x: F.silu(x)
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
+        if act_fn is None:
+            self.act = None
+        else:
+            self.act = get_activation(act_fn)

         self.linear = nn.Linear(embedding_dim, out_dim * 2)


@@ -18,6 +18,8 @@ import numpy as np
 import torch
 from torch import nn

+from .activations import get_activation

 def get_timestep_embedding(
     timesteps: torch.Tensor,
@@ -171,14 +173,7 @@ class TimestepEmbedding(nn.Module):
         else:
             self.cond_proj = None

-        if act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
-        else:
-            raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+        self.act = get_activation(act_fn)

         if out_dim is not None:
             time_embed_dim_out = out_dim
@@ -188,14 +183,8 @@ class TimestepEmbedding(nn.Module):
         if post_act_fn is None:
             self.post_act = None
-        elif post_act_fn == "silu":
-            self.post_act = nn.SiLU()
-        elif post_act_fn == "mish":
-            self.post_act = nn.Mish()
-        elif post_act_fn == "gelu":
-            self.post_act = nn.GELU()
-        else:
-            raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+        else:
+            self.post_act = get_activation(post_act_fn)

     def forward(self, sample, condition=None):
         if condition is not None:


@@ -20,6 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import SpatialNorm
@@ -558,14 +559,7 @@ class ResnetBlock2D(nn.Module):
         conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        elif non_linearity == "gelu":
-            self.nonlinearity = nn.GELU()
+        self.nonlinearity = get_activation(non_linearity)

         self.upsample = self.downsample = None
         if self.up:
@@ -646,11 +640,6 @@ class ResnetBlock2D(nn.Module):
         return output_tensor


-class Mish(torch.nn.Module):
-    def forward(self, hidden_states):
-        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


 # unet_rl.py
 def rearrange_dims(tensor):
     if len(tensor.shape) == 2:


@@ -17,6 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

+from .activations import get_activation
 from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
@@ -55,14 +56,10 @@ class DownResnetBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
-            self.nonlinearity = None
+        if non_linearity is None:
+            self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)

         self.downsample = None
         if add_downsample:
@@ -119,14 +116,10 @@ class UpResnetBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
-            self.nonlinearity = None
+        if non_linearity is None:
+            self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)

         self.upsample = None
         if add_upsample:
@@ -194,14 +187,10 @@ class MidResTemporalBlock1D(nn.Module):
         self.resnets = nn.ModuleList(resnets)

-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
-            self.nonlinearity = None
+        if non_linearity is None:
+            self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)

         self.upsample = None
         if add_upsample:
@@ -232,10 +221,7 @@ class OutConv1DBlock(nn.Module):
         super().__init__()
         self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
         self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
-        if act_fn == "silu":
-            self.final_conv1d_act = nn.SiLU()
-        if act_fn == "mish":
-            self.final_conv1d_act = nn.Mish()
+        self.final_conv1d_act = get_activation(act_fn)
         self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)

     def forward(self, hidden_states, temb=None):


@@ -16,12 +16,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
+from .activations import get_activation
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import (
     GaussianFourierProjection,
@@ -338,16 +338,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
-        else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+        else:
+            self.time_embed_act = get_activation(time_embedding_act_fn)

         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -501,16 +493,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )

-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)

         else:
             self.conv_norm_out = None


@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
+from ...models.activations import get_activation
 from ...models.attention import Attention
 from ...models.attention_processor import (
     AttentionProcessor,
@@ -441,16 +442,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
-        else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+        else:
+            self.time_embed_act = get_activation(time_embedding_act_fn)

         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -604,16 +597,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )

-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)

         else:
             self.conv_norm_out = None


@@ -0,0 +1,48 @@
+import unittest
+
+import torch
+from torch import nn
+
+from diffusers.models.activations import get_activation
+
+
+class ActivationsTests(unittest.TestCase):
+    def test_swish(self):
+        act = get_activation("swish")
+
+        self.assertIsInstance(act, nn.SiLU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_silu(self):
+        act = get_activation("silu")
+
+        self.assertIsInstance(act, nn.SiLU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_mish(self):
+        act = get_activation("mish")
+
+        self.assertIsInstance(act, nn.Mish)
+
+        self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_gelu(self):
+        act = get_activation("gelu")
+
+        self.assertIsInstance(act, nn.GELU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
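
For reference, the boundary values asserted above follow from the activations' saturation behavior; a plain-PyTorch sanity check, independent of diffusers (not part of the commit):

import torch
from torch import nn

silu = nn.SiLU()
# Large negative inputs saturate to (approximately) zero, zero maps to zero, and
# large positive inputs pass through essentially unchanged in float32.
print(silu(torch.tensor(-100.0)).item())  # ~0.0: sigmoid(-100) underflows in float32
print(silu(torch.tensor(0.0)).item())     # 0.0
print(silu(torch.tensor(20.0)).item())    # 20.0: 20 * sigmoid(20) rounds to 20 in float32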