mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)

start kernelize.
@@ -20,6 +20,7 @@ import torch.nn.functional as F
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
+from ..utils.kernels_utils import use_kernel_forward_from_hub
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
@@ -1669,6 +1670,7 @@ class FreeNoiseTransformerBlock(nn.Module):
         return hidden_states


+@use_kernel_forward_from_hub("MLP")
 class FeedForward(nn.Module):
     r"""
     A feed-forward layer.
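The decorator added above does not change FeedForward's behavior on its own: in the `kernels` library it only tags the class with the kernel layer name ("MLP") so that a later kernelization step can swap in a Hub-hosted forward. A minimal sketch of that tagging idea, using a simplified stand-in for `kernels.use_kernel_forward_from_hub` (the decorator and attribute names below are hypothetical, for illustration only):

import torch
import torch.nn as nn


def tag_kernel_layer(layer_name: str):
    # Simplified stand-in: record the kernel layer name on the class and leave
    # the native PyTorch forward untouched.
    def decorator(cls):
        cls.kernel_layer_name = layer_name  # hypothetical attribute name
        return cls

    return decorator


@tag_kernel_layer("MLP")
class ToyFeedForward(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)


print(ToyFeedForward.kernel_layer_name)  # "MLP" -> used later to look up a Hub kernel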
@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from ..utils import is_torch_npu_available, is_torch_version
+from ..utils.kernels_utils import use_kernel_forward_from_hub
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
@@ -508,6 +509,7 @@ else:
             return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


+@use_kernel_forward_from_hub("RMSNorm")
 class RMSNorm(nn.Module):
     r"""
     RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
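For reference, the computation that the CUDA kernel registered for "RMSNorm" below would replace: RMS Norm rescales the input by the reciprocal root-mean-square over its last dimension and multiplies by a learned weight. A minimal eager-mode sketch, assuming a simplified signature (dtype and affine-parameter handling in diffusers' RMSNorm class may differ in detail):

import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Root-mean-square normalization over the last dimension, computed in float32
    # for stability, then cast back and scaled by the learned weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = x.float() * torch.rsqrt(variance + eps)
    return weight * normed.to(x.dtype)


x = torch.randn(2, 4, 16)
w = torch.ones(16)
print(rms_norm_reference(x, w).shape)  # torch.Size([2, 4, 16])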
@@ -1,3 +1,5 @@
+from typing import Union
+
 from ..utils import get_logger
 from .import_utils import is_kernels_available
@@ -21,3 +23,43 @@ def _get_fa3_from_hub():
     except Exception as e:
         logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
         raise
+
+
+if is_kernels_available():
+    from kernels import (
+        Device,
+        LayerRepository,
+        register_kernel_mapping,
+        replace_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
+    )
+
+    _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
+        "RMSNorm": {
+            "cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
+        },
+        "MLP": {"cuda": LayerRepository(repo_id="medmekk/triton-llama-mlp", layer_name="TritonLlamaMLP")},
+    }
+
+    register_kernel_mapping(_KERNEL_MAPPING)
+
+else:
+    # Stub to make the decorators work when `kernels`
+    # is not installed.
+    def use_kernel_forward_from_hub(*args, **kwargs):
+        def decorator(cls):
+            return cls
+
+        return decorator
+
+    class LayerRepository:
+        def __init__(self, *args, **kwargs):
+            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
+
+    def replace_kernel_forward_from_hub(*args, **kwargs):
+        raise RuntimeError(
+            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
+        )
+
+    def register_kernel_mapping(*args, **kwargs):
+        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
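With the mapping registered at import time, layers that carry the decorator can have their forwards swapped for the Hub kernels. A usage sketch, not part of this commit: it assumes `pip install kernels`, a CUDA device, network access to download the kernel, the `kernels` library's kernelize()/Mode API, and that the LigerRMSNorm Hub layer reads `weight` and `variance_epsilon` attributes from the local module (attribute-name compatibility is assumed, not verified here). The toy class below is illustrative, not the class from this diff:

import torch
import torch.nn as nn
from kernels import LayerRepository, Mode, kernelize, register_kernel_mapping, use_kernel_forward_from_hub

# Same "RMSNorm" -> CUDA mapping that the diff registers inside diffusers.
register_kernel_mapping(
    {"RMSNorm": {"cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm")}}
)


@use_kernel_forward_from_hub("RMSNorm")  # layer name must match the key in the mapping
class ToyRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.variance_epsilon = eps  # name assumed to match what the Hub layer expects

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.float().pow(2).mean(dim=-1, keepdim=True)
        return self.weight * (x.float() * torch.rsqrt(variance + self.variance_epsilon)).to(x.dtype)


model = nn.Sequential(nn.Linear(16, 16), ToyRMSNorm(16)).to("cuda", torch.float16)
# Replace the forward of every decorated layer with the kernel registered under "RMSNorm".
model = kernelize(model, mode=Mode.INFERENCE, device="cuda")
print(model(torch.randn(2, 16, device="cuda", dtype=torch.float16)).shape)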