From ce7f334472abf0bca282d27a7d519226828d7647 Mon Sep 17 00:00:00 2001 From: Chi Date: Thu, 26 Oct 2023 09:36:30 +0530 Subject: [PATCH] Remove multiple if-else statement in the get_activation function. (#5446) * I added a new doc string to the class. This is more flexible to understanding other developers what are doing and where it's using. * Update src/diffusers/models/unet_2d_blocks.py This changes suggest by maintener. Co-authored-by: Sayak Paul * Update src/diffusers/models/unet_2d_blocks.py Add suggested text Co-authored-by: Sayak Paul * Update unet_2d_blocks.py I changed the Parameter to Args text. * Update unet_2d_blocks.py proper indentation set in this file. * Update unet_2d_blocks.py a little bit of change in the act_fun argument line. * I run the black command to reformat style in the code * Update unet_2d_blocks.py similar doc-string add to have in the original diffusion repository. * I use a lower method in the activation function. * Replace multiple if-else statements with a dictionary of activation functions, and call one if statement to retrieve the appropriate function. * I am using black package to reforamted my file * I defined the ACTIVATION_FUNCTIONS variable outside of the function * activation function variable convert to lower case * First, I resolved the conflict issue. Then, I ran the Black package to reformat my file. --------- Co-authored-by: Sayak Paul --- src/diffusers/models/activations.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index e66d90040f..8b75162ba5 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -21,6 +21,15 @@ from ..utils import USE_PEFT_BACKEND from .lora import LoRACompatibleLinear +ACTIVATION_FUNCTIONS = { + "swish": nn.SiLU(), + "silu": nn.SiLU(), + "mish": nn.Mish(), + "gelu": nn.GELU(), + "relu": nn.ReLU(), +} + + def get_activation(act_fn: str) -> nn.Module: """Helper function to get activation function from string. @@ -30,14 +39,10 @@ def get_activation(act_fn: str) -> nn.Module: Returns: nn.Module: Activation function. """ - if act_fn in ["swish", "silu"]: - return nn.SiLU() - elif act_fn == "mish": - return nn.Mish() - elif act_fn == "gelu": - return nn.GELU() - elif act_fn == "relu": - return nn.ReLU() + + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] else: raise ValueError(f"Unsupported activation function: {act_fn}")