diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index e66d90040f..8b75162ba5 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -21,6 +21,15 @@ from ..utils import USE_PEFT_BACKEND
 from .lora import LoRACompatibleLinear
 
 
+ACTIVATION_FUNCTIONS = {
+    "swish": nn.SiLU(),
+    "silu": nn.SiLU(),
+    "mish": nn.Mish(),
+    "gelu": nn.GELU(),
+    "relu": nn.ReLU(),
+}
+
+
 def get_activation(act_fn: str) -> nn.Module:
     """Helper function to get activation function from string.
 
@@ -30,14 +39,10 @@ def get_activation(act_fn: str) -> nn.Module:
     Returns:
         nn.Module: Activation function.
     """
-    if act_fn in ["swish", "silu"]:
-        return nn.SiLU()
-    elif act_fn == "mish":
-        return nn.Mish()
-    elif act_fn == "gelu":
-        return nn.GELU()
-    elif act_fn == "relu":
-        return nn.ReLU()
+
+    act_fn = act_fn.lower()
+    if act_fn in ACTIVATION_FUNCTIONS:
+        return ACTIVATION_FUNCTIONS[act_fn]
     else:
         raise ValueError(f"Unsupported activation function: {act_fn}")