diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 7ab371a1a1..d7a236462b 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -22,7 +22,8 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -40,6 +41,12 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
+    from kernels import get_kernel
+
+    activation = get_kernel("kernels-community/activation", revision="add_more_act")
+    gelu_tanh_kernel = activation.gelu_tanh
+
 
 def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
     query = attn.to_q(hidden_states)
@@ -350,7 +357,11 @@ class FluxSingleTransformerBlock(nn.Module):
 
         self.norm = AdaLayerNormZeroSingle(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
-        self.act_mlp = nn.GELU(approximate="tanh")
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.act_mlp = nn.GELU(approximate="tanh")
+        else:
+            self.act_mlp = gelu_tanh_kernel
+
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         self.attn = FluxAttention(
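
For context, a minimal usage sketch of the new code path. It assumes the `DIFFUSERS_ENABLE_HUB_KERNELS` constant is driven by an environment variable of the same name and that any truthy value enables it; the checkpoint id and dtype below are illustrative and not part of this diff.

```python
# Hedged sketch: opting in to the Hub-kernel GELU path before diffusers is imported.
# Assumptions (not in this diff): the constant is read from an env var of the same
# name, the `kernels` package is installed, and a CUDA device is available.
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
).to("cuda")
```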
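
A quick numerical sanity check one could run for the swapped activation. It assumes a CUDA device, an installed `kernels` package, and that the fetched `gelu_tanh` accepts a single tensor, which is what its drop-in assignment to `self.act_mlp` implies; tensor shape and tolerances are illustrative.

```python
# Hedged parity check between the Hub kernel and PyTorch's tanh-approximated GELU.
import torch
import torch.nn.functional as F
from kernels import get_kernel

activation = get_kernel("kernels-community/activation", revision="add_more_act")

x = torch.randn(2, 1024, 3072 * 4, device="cuda", dtype=torch.bfloat16)
ref = F.gelu(x, approximate="tanh")
out = activation.gelu_tanh(x)  # assumption: single-tensor call, mirroring act_mlp usage in the diff

torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```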