From f1cb807496ab456c52b48574171aaa83902fab6d Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 30 Jun 2022 12:24:47 +0200 Subject: [PATCH] remove get_act --- .../models/unet_sde_score_estimation.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 6f909dcf3b..6eed6791d0 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -295,21 +295,6 @@ def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale return conv -def get_act(nonlinearity): - """Get activation functions from the config file.""" - - if nonlinearity.lower() == "elu": - return nn.ELU() - elif nonlinearity.lower() == "relu": - return nn.ReLU() - elif nonlinearity.lower() == "lrelu": - return nn.LeakyReLU(negative_slope=0.2) - elif nonlinearity.lower() == "swish": - return nn.SiLU() - else: - raise NotImplementedError("activation function does not exist!") - - def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"): """Ported from JAX.""" scale = 1e-10 if scale == 0 else scale @@ -467,7 +452,7 @@ class NCSNpp(ModelMixin, ConfigMixin): skip_rescale=skip_rescale, continuous=continuous, ) - self.act = act = get_act(nonlinearity) + self.act = act = nn.SiLU() self.nf = nf self.num_res_blocks = num_res_blocks