1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
sayakpaul
2025-09-18 14:49:48 +05:30
parent 58743c3ee7
commit 33a8a3be0c
2 changed files with 39 additions and 32 deletions

View File

@@ -17,8 +17,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate, get_logger, is_kernels_available, is_torch_npu_available, is_torch_version
from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
from ..utils import deprecate, get_logger, is_torch_npu_available, is_torch_version
logger = get_logger(__name__)
@@ -93,36 +92,24 @@ class GELU(nn.Module):
return hidden_states
class CUDAOptimizedGELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
if not torch.cuda.is_available():
raise NotImplementedError(f"{self.__class__.__name__} is implemented only for CUDA devices.")
if not DIFFUSERS_ENABLE_HUB_KERNELS:
raise RuntimeError(
f"{self.__class__.__name__} isn't usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
)
if not is_kernels_available():
raise NotImplementedError(
f"{self.__class__.__name__} requires the `kernels` library to be installed. Install it with `pip install kernels`."
)
# TODO: validation checks / consider making Python classes of activations like `transformers`
# All of these are temporary for now.
class CUDAOptimizedGELU(GELU):
def __init__(self, *args, **kwargs):
from kernels import get_kernel
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
activations = get_kernel(KERNELS_REPO_ID)
if approximate == "tanh":
self.act = activations.gelu_tanh_and_mul
elif approximate == "none":
self.act = activations.gelu_and_mul
else:
raise NotImplementedError
activation = get_kernel("kernels-community/activation", revision="add_more_act")
approximate = kwargs.get("approximate", "none")
if approximate == "none":
self.act_fn = activation.gelu
elif approximate == "tanh":
self.act_fn = activation.gelu_tanh
super().__init__(*args, **kwargs)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
out = torch.empty_like(hidden_states)
output = self.act(out, hidden_states)
return output
hidden_states = self.act_fn(hidden_states)
return hidden_states
class GEGLU(nn.Module):

View File

@@ -20,12 +20,20 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import is_torch_npu_available, is_torch_version
from ..utils import is_kernels_available, is_torch_npu_available, is_torch_version
from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
from ..utils.kernels_utils import use_kernel_forward_from_hub
from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
from kernels import get_kernel
activation = get_kernel("kernels-community/activation", revision="add_more_act")
silu_kernel = activation.silu
class AdaLayerNorm(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
@@ -58,7 +66,10 @@ class AdaLayerNorm(nn.Module):
else:
self.emb = None
self.silu = nn.SiLU()
if not DIFFUSERS_ENABLE_HUB_KERNELS:
self.silu = nn.SiLU()
else:
self.silu = silu_kernel
self.linear = nn.Linear(embedding_dim, output_dim)
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
@@ -145,7 +156,10 @@ class AdaLayerNormZero(nn.Module):
else:
self.emb = None
self.silu = nn.SiLU()
if not DIFFUSERS_ENABLE_HUB_KERNELS:
self.silu = nn.SiLU()
else:
self.silu = silu_kernel
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -184,7 +198,10 @@ class AdaLayerNormZeroSingle(nn.Module):
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
super().__init__()
self.silu = nn.SiLU()
if not DIFFUSERS_ENABLE_HUB_KERNELS:
self.silu = nn.SiLU()
else:
self.silu = silu_kernel
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -336,7 +353,10 @@ class AdaLayerNormContinuous(nn.Module):
norm_type="layer_norm",
):
super().__init__()
self.silu = nn.SiLU()
if not DIFFUSERS_ENABLE_HUB_KERNELS:
self.silu = nn.SiLU()
else:
self.silu = silu_kernel
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)