
start kernelize.

sayakpaul
2025-09-15 16:26:52 +05:30
parent f5c113e439
commit 50c0b786d2
3 changed files with 46 additions and 0 deletions

View File

@@ -20,6 +20,7 @@ import torch.nn.functional as F
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
from ..utils.kernels_utils import use_kernel_forward_from_hub
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
@@ -1669,6 +1670,7 @@ class FreeNoiseTransformerBlock(nn.Module):
        return hidden_states


@use_kernel_forward_from_hub("MLP")
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

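For orientation, a minimal sketch (not diffusers code) of what the decorator by itself does: it only tags the class with a kernel name ("MLP" here), and the eager forward keeps running until kernelize() is later called on a model that contains the layer. ToyMLP is an illustrative stand-in; it assumes the `kernels` package is installed.

# Illustrative only: the decorator tags the class, nothing is swapped yet.
import torch
import torch.nn as nn
import torch.nn.functional as F
from kernels import use_kernel_forward_from_hub


@use_kernel_forward_from_hub("MLP")
class ToyMLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Runs as-is until kernelize() replaces it with the mapped Hub kernel.
        return self.fc2(F.gelu(self.fc1(x)))


print(ToyMLP(8)(torch.randn(2, 8)).shape)  # torch.Size([2, 8])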
View File

@@ -21,6 +21,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ..utils import is_torch_npu_available, is_torch_version
from ..utils.kernels_utils import use_kernel_forward_from_hub
from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
@@ -508,6 +509,7 @@ else:
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


@use_kernel_forward_from_hub("RMSNorm")
class RMSNorm(nn.Module):
    r"""
    RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.

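And a hedged sketch of the consumer side: once the mapping registered in the third changed file below is in place, calling kernelize() on a module tree should route decorated layers such as this RMSNorm to the mapped Hub kernel on CUDA. This assumes `kernels` is installed and a CUDA device is available; whether the Liger kernel is actually picked up depends on the installed kernels version and device.

# Sketch: kernelize() swaps forwards of decorated layers according to the
# registered mapping ("RMSNorm" -> LigerRMSNorm on CUDA).
import torch
from kernels import Mode, kernelize
from diffusers.models.normalization import RMSNorm

norm = RMSNorm(dim=64, eps=1e-6).to("cuda", torch.float16)
norm = kernelize(norm, mode=Mode.INFERENCE)
print(norm(torch.randn(2, 16, 64, device="cuda", dtype=torch.float16)).shape)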
View File

@@ -1,3 +1,5 @@
from typing import Union
from ..utils import get_logger
from .import_utils import is_kernels_available
@@ -21,3 +23,43 @@ def _get_fa3_from_hub():
        except Exception as e:
            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
            raise
if is_kernels_available():
    from kernels import (
        Device,
        LayerRepository,
        register_kernel_mapping,
        replace_kernel_forward_from_hub,
        use_kernel_forward_from_hub,
    )

    _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
        "RMSNorm": {
            "cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
        },
        "MLP": {"cuda": LayerRepository(repo_id="medmekk/triton-llama-mlp", layer_name="TritonLlamaMLP")},
    }

    register_kernel_mapping(_KERNEL_MAPPING)

else:
    # Stub to make the decorators work when `kernels` is not installed.
    def use_kernel_forward_from_hub(*args, **kwargs):
        def decorator(cls):
            return cls

        return decorator

    class LayerRepository:
        def __init__(self, *args, **kwargs):
            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")

    def replace_kernel_forward_from_hub(*args, **kwargs):
        raise RuntimeError(
            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
        )

    def register_kernel_mapping(*args, **kwargs):
        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
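
For completeness, a sketch of how the default mapping above could be extended or overridden downstream via the same register_kernel_mapping API; the repo id and layer name here are placeholders, not real kernels.

# Hypothetical override: point the "RMSNorm" kernel name at a different Hub repo.
from kernels import LayerRepository, register_kernel_mapping

register_kernel_mapping(
    {
        "RMSNorm": {
            "cuda": LayerRepository(repo_id="my-org/my-norm-kernels", layer_name="FastRMSNorm"),
        }
    }
)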