Mirror of https://github.com/huggingface/diffusers.git

Commit message: up
@@ -22,7 +22,8 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -40,6 +41,12 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
+    from kernels import get_kernel
+
+    activation = get_kernel("kernels-community/activation", revision="add_more_act")
+    gelu_tanh_kernel = activation.gelu_tanh
+
 
 def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
     query = attn.to_q(hidden_states)
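
The hunk above only loads the fused activation when both the optional kernels package and the opt-in flag are present. As a standalone sanity check (not part of the patch), the snippet below compares the hub kernel against the eager fallback the patch keeps as the default. The gelu_tanh(x) -> Tensor call signature is inferred from how the patch assigns it to self.act_mlp, so treat it as an assumption; running it needs a CUDA GPU, network access, and `pip install kernels`.

import torch
import torch.nn as nn
from kernels import get_kernel

# Same lookup as the patch: a precompiled activation kernel fetched from the Hub.
activation = get_kernel("kernels-community/activation", revision="add_more_act")

x = torch.randn(4, 3072, device="cuda", dtype=torch.bfloat16)
eager = nn.GELU(approximate="tanh")(x)  # default path kept by the patch
fused = activation.gelu_tanh(x)         # hub-kernel path (assumed signature)

print("max abs diff:", (fused - eager).abs().max().item())

If the two outputs agree to within bf16 tolerance, swapping self.act_mlp is a drop-in change; the patch only takes the fused path when is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS are both true.
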
@@ -350,7 +357,11 @@ class FluxSingleTransformerBlock(nn.Module):
 
         self.norm = AdaLayerNormZeroSingle(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
-        self.act_mlp = nn.GELU(approximate="tanh")
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.act_mlp = nn.GELU(approximate="tanh")
+        else:
+            self.act_mlp = gelu_tanh_kernel
+
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         self.attn = FluxAttention(
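
For readers outside the diffusers codebase, here is a minimal, self-contained sketch of the construction-time switch this hunk adds to FluxSingleTransformerBlock. ToyBlock, use_hub_kernel, and gelu_kernel are illustrative names, not diffusers API; the real block also runs attention and AdaLayerNorm modulation, which are omitted here.

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim: int, mlp_ratio: float = 4.0, use_hub_kernel: bool = False, gelu_kernel=None):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        # Pick the activation once at construction time, mirroring the patch:
        # eager GELU(tanh) by default, a fused hub kernel only when explicitly enabled.
        if use_hub_kernel and gelu_kernel is not None:
            self.act_mlp = gelu_kernel
        else:
            self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        mlp = self.act_mlp(self.proj_mlp(hidden_states))
        # Concatenate the stream with the MLP branch so the width matches proj_out;
        # the real Flux single block concatenates the attention output instead.
        return self.proj_out(torch.cat([hidden_states, mlp], dim=-1))


block = ToyBlock(dim=64)
out = block(torch.randn(2, 8, 64))
print(out.shape)  # torch.Size([2, 8, 64])

Choosing the activation once in __init__ keeps the forward path branch-free, so enabling or disabling the hub kernel adds no per-step overhead.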