mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
remove get_act
This commit is contained in:
@@ -295,21 +295,6 @@ def conv2d(in_planes, out_planes, kernel_size=3, stride=1, bias=True, init_scale
|
||||
return conv
|
||||
|
||||
|
||||
def get_act(nonlinearity):
|
||||
"""Get activation functions from the config file."""
|
||||
|
||||
if nonlinearity.lower() == "elu":
|
||||
return nn.ELU()
|
||||
elif nonlinearity.lower() == "relu":
|
||||
return nn.ReLU()
|
||||
elif nonlinearity.lower() == "lrelu":
|
||||
return nn.LeakyReLU(negative_slope=0.2)
|
||||
elif nonlinearity.lower() == "swish":
|
||||
return nn.SiLU()
|
||||
else:
|
||||
raise NotImplementedError("activation function does not exist!")
|
||||
|
||||
|
||||
def variance_scaling(scale=1.0, in_axis=1, out_axis=0, dtype=torch.float32, device="cpu"):
|
||||
"""Ported from JAX."""
|
||||
scale = 1e-10 if scale == 0 else scale
|
||||
@@ -467,7 +452,7 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
skip_rescale=skip_rescale,
|
||||
continuous=continuous,
|
||||
)
|
||||
self.act = act = get_act(nonlinearity)
|
||||
self.act = act = nn.SiLU()
|
||||
|
||||
self.nf = nf
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
Reference in New Issue
Block a user