Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
move activation dispatches into helper function (#3656)
* move activation dispatches into helper function
* tests
src/diffusers/models/activations.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+from torch import nn
+
+
+def get_activation(act_fn):
+    if act_fn in ["swish", "silu"]:
+        return nn.SiLU()
+    elif act_fn == "mish":
+        return nn.Mish()
+    elif act_fn == "gelu":
+        return nn.GELU()
+    else:
+        raise ValueError(f"Unsupported activation function: {act_fn}")
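The helper centralizes the name-to-module mapping that each block previously re-implemented inline. A minimal usage sketch of the resulting pattern (not part of the PR; `ToyBlock` is a hypothetical module used only for illustration):

import torch
from torch import nn

from diffusers.models.activations import get_activation  # helper introduced by this PR


class ToyBlock(nn.Module):
    """Hypothetical module illustrating the new dispatch pattern."""

    def __init__(self, dim: int, act_fn: str = "silu"):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Previously each module carried its own if/elif over
        # "swish" / "silu" / "mish" / "gelu"; now the lookup lives in one place.
        self.act = get_activation(act_fn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.proj(x))

Unknown names raise a single `ValueError` inside the helper, which is why the per-call-site error messages in `TimestepEmbedding`, `UNet2DConditionModel`, and `UNetFlatConditionModel` disappear in the hunks below.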
src/diffusers/models/attention.py
@@ -18,6 +18,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from ..utils import maybe_allow_in_graph
+from .activations import get_activation
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
 
@@ -345,15 +346,11 @@ class AdaGroupNorm(nn.Module):
         super().__init__()
         self.num_groups = num_groups
         self.eps = eps
-        self.act = None
-        if act_fn == "swish":
-            self.act = lambda x: F.silu(x)
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
+
+        if act_fn is None:
+            self.act = None
+        else:
+            self.act = get_activation(act_fn)
 
         self.linear = nn.Linear(embedding_dim, out_dim * 2)
 
src/diffusers/models/embeddings.py
@@ -18,6 +18,8 @@ import numpy as np
 import torch
 from torch import nn
 
+from .activations import get_activation
+
 
 def get_timestep_embedding(
     timesteps: torch.Tensor,
@@ -171,14 +173,7 @@ class TimestepEmbedding(nn.Module):
         else:
             self.cond_proj = None
 
-        if act_fn == "silu":
-            self.act = nn.SiLU()
-        elif act_fn == "mish":
-            self.act = nn.Mish()
-        elif act_fn == "gelu":
-            self.act = nn.GELU()
-        else:
-            raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+        self.act = get_activation(act_fn)
 
         if out_dim is not None:
             time_embed_dim_out = out_dim
@@ -188,14 +183,8 @@
 
         if post_act_fn is None:
             self.post_act = None
-        elif post_act_fn == "silu":
-            self.post_act = nn.SiLU()
-        elif post_act_fn == "mish":
-            self.post_act = nn.Mish()
-        elif post_act_fn == "gelu":
-            self.post_act = nn.GELU()
         else:
-            raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
+            self.post_act = get_activation(post_act_fn)
 
     def forward(self, sample, condition=None):
         if condition is not None:
src/diffusers/models/resnet.py
@@ -20,6 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import SpatialNorm
 
@@ -558,14 +559,7 @@ class ResnetBlock2D(nn.Module):
         conv_2d_out_channels = conv_2d_out_channels or out_channels
         self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        elif non_linearity == "gelu":
-            self.nonlinearity = nn.GELU()
+        self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = self.downsample = None
         if self.up:
@@ -646,11 +640,6 @@ class ResnetBlock2D(nn.Module):
         return output_tensor
 
 
-class Mish(torch.nn.Module):
-    def forward(self, hidden_states):
-        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
-
-
 # unet_rl.py
 def rearrange_dims(tensor):
     if len(tensor.shape) == 2:
src/diffusers/models/unet_1d_blocks.py
@@ -17,6 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
+from .activations import get_activation
 from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
 
 
@@ -55,14 +56,10 @@ class DownResnetBlock1D(nn.Module):
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.downsample = None
         if add_downsample:
@@ -119,14 +116,10 @@ class UpResnetBlock1D(nn.Module):
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = None
         if add_upsample:
@@ -194,14 +187,10 @@ class MidResTemporalBlock1D(nn.Module):
 
         self.resnets = nn.ModuleList(resnets)
 
-        if non_linearity == "swish":
-            self.nonlinearity = lambda x: F.silu(x)
-        elif non_linearity == "mish":
-            self.nonlinearity = nn.Mish()
-        elif non_linearity == "silu":
-            self.nonlinearity = nn.SiLU()
-        else:
+        if non_linearity is None:
             self.nonlinearity = None
+        else:
+            self.nonlinearity = get_activation(non_linearity)
 
         self.upsample = None
         if add_upsample:
@@ -232,10 +221,7 @@ class OutConv1DBlock(nn.Module):
         super().__init__()
         self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
         self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
-        if act_fn == "silu":
-            self.final_conv1d_act = nn.SiLU()
-        if act_fn == "mish":
-            self.final_conv1d_act = nn.Mish()
+        self.final_conv1d_act = get_activation(act_fn)
         self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
 
     def forward(self, hidden_states, temb=None):
src/diffusers/models/unet_2d_condition.py
@@ -16,12 +16,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
+from .activations import get_activation
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import (
     GaussianFourierProjection,
@@ -338,16 +338,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)
 
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -501,16 +493,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
 
-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
 
         else:
             self.conv_norm_out = None
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
+from ...models.activations import get_activation
 from ...models.attention import Attention
 from ...models.attention_processor import (
     AttentionProcessor,
@@ -441,16 +442,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
 
         if time_embedding_act_fn is None:
             self.time_embed_act = None
-        elif time_embedding_act_fn == "swish":
-            self.time_embed_act = lambda x: F.silu(x)
-        elif time_embedding_act_fn == "mish":
-            self.time_embed_act = nn.Mish()
-        elif time_embedding_act_fn == "silu":
-            self.time_embed_act = nn.SiLU()
-        elif time_embedding_act_fn == "gelu":
-            self.time_embed_act = nn.GELU()
         else:
-            raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
+            self.time_embed_act = get_activation(time_embedding_act_fn)
 
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
@@ -604,16 +597,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
 
-            if act_fn == "swish":
-                self.conv_act = lambda x: F.silu(x)
-            elif act_fn == "mish":
-                self.conv_act = nn.Mish()
-            elif act_fn == "silu":
-                self.conv_act = nn.SiLU()
-            elif act_fn == "gelu":
-                self.conv_act = nn.GELU()
-            else:
-                raise ValueError(f"Unsupported activation function: {act_fn}")
+            self.conv_act = get_activation(act_fn)
 
         else:
             self.conv_norm_out = None
tests/models/test_activations.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+import unittest
+
+import torch
+from torch import nn
+
+from diffusers.models.activations import get_activation
+
+
+class ActivationsTests(unittest.TestCase):
+    def test_swish(self):
+        act = get_activation("swish")
+
+        self.assertIsInstance(act, nn.SiLU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_silu(self):
+        act = get_activation("silu")
+
+        self.assertIsInstance(act, nn.SiLU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_mish(self):
+        act = get_activation("mish")
+
+        self.assertIsInstance(act, nn.Mish)
+
+        self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
+
+    def test_gelu(self):
+        act = get_activation("gelu")
+
+        self.assertIsInstance(act, nn.GELU)
+
+        self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0)
+        self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0)
+        self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20)
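A quick local sanity check mirroring what the new tests assert (a sketch, assuming an installed diffusers build that includes this commit):

import torch
from diffusers.models.activations import get_activation

act = get_activation("swish")           # returns an nn.SiLU instance
x = torch.tensor([-100.0, 0.0, 20.0])
print(act(x))                           # approximately tensor([ 0.,  0., 20.])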