diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index c720b37955..81e42509f9 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -20,6 +20,7 @@ import torch.nn.functional as F
 
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
+from ..utils.kernels_utils import use_kernel_forward_from_hub
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
@@ -1669,6 +1670,7 @@ class FreeNoiseTransformerBlock(nn.Module):
         return hidden_states
 
 
+@use_kernel_forward_from_hub("MLP")
 class FeedForward(nn.Module):
     r"""
     A feed-forward layer.
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index ae2a6298f5..1d92dd5c60 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from ..utils import is_torch_npu_available, is_torch_version
+from ..utils.kernels_utils import use_kernel_forward_from_hub
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
@@ -508,6 +509,7 @@ else:
             return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
 
 
+@use_kernel_forward_from_hub("RMSNorm")
 class RMSNorm(nn.Module):
     r"""
     RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
index 26d6e3972f..5468657a3a 100644
--- a/src/diffusers/utils/kernels_utils.py
+++ b/src/diffusers/utils/kernels_utils.py
@@ -1,3 +1,5 @@
+from typing import Union
+
 from ..utils import get_logger
 from .import_utils import is_kernels_available
 
@@ -21,3 +23,43 @@ def _get_fa3_from_hub():
         except Exception as e:
             logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
             raise
+
+
+if is_kernels_available():
+    from kernels import (
+        Device,
+        LayerRepository,
+        register_kernel_mapping,
+        replace_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
+    )
+
+    _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
+        "RMSNorm": {
+            "cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
+        },
+        "MLP": {"cuda": LayerRepository(repo_id="medmekk/triton-llama-mlp", layer_name="TritonLlamaMLP")},
+    }
+
+    register_kernel_mapping(_KERNEL_MAPPING)
+
+else:
+    # Stub to make the decorators in diffusers work when `kernels`
+    # is not installed.
+    def use_kernel_forward_from_hub(*args, **kwargs):
+        def decorator(cls):
+            return cls
+
+        return decorator
+
+    class LayerRepository:
+        def __init__(self, *args, **kwargs):
+            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
+
+    def replace_kernel_forward_from_hub(*args, **kwargs):
+        raise RuntimeError(
+            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
+        )
+
+    def register_kernel_mapping(*args, **kwargs):
+        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
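
Usage sketch (illustrative, not part of the patch): with `kernels` installed, the decorators above plus the registered `_KERNEL_MAPPING` are what allow a decorated layer's `forward` to be swapped for a Hub kernel. The snippet below assumes the `kernelize`/`Mode` API of the `kernels` package and a CUDA device; exact names and signatures may differ between `kernels` versions.

```python
# Hypothetical end-to-end check of the new mapping (assumes `kernels` is
# installed and a CUDA GPU is available).
import torch
from kernels import Mode, kernelize

from diffusers.models.normalization import RMSNorm

# RMSNorm is now decorated with @use_kernel_forward_from_hub("RMSNorm"), and
# register_kernel_mapping() maps "RMSNorm" on CUDA to
# kernels-community/liger_kernels:LigerRMSNorm.
norm = RMSNorm(dim=64, eps=1e-6).to(device="cuda", dtype=torch.float16)

# kernelize() walks the module tree and replaces the forward of decorated
# layers with the registered Hub kernel for the current device/mode.
norm = kernelize(norm, mode=Mode.INFERENCE, device="cuda")

x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)
print(norm(x).shape)  # forward now routed through the Liger RMSNorm kernel
```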