From 92199ff3ac91817ba8d1ec6e20f4256595078807 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 22 Sep 2025 16:46:49 +0530
Subject: [PATCH] up

---
 src/diffusers/models/normalization.py          | 24 +++++++--------
 .../models/transformers/transformer_flux.py    | 29 ++++++++++++++-----
 2 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 9943e9f6c8..84e2de0830 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -66,10 +66,10 @@ class AdaLayerNorm(nn.Module):
         else:
             self.emb = None
 
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.silu = nn.SiLU()
-        else:
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, output_dim)
         self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
 
@@ -156,10 +156,10 @@ class AdaLayerNormZero(nn.Module):
         else:
             self.emb = None
 
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.silu = nn.SiLU()
-        else:
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -198,10 +198,10 @@ class AdaLayerNormZeroSingle(nn.Module):
     def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
         super().__init__()
 
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.silu = nn.SiLU()
-        else:
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -353,10 +353,10 @@ class AdaLayerNormContinuous(nn.Module):
         norm_type="layer_norm",
     ):
         super().__init__()
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.silu = nn.SiLU()
-        else:
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 73cba6c9de..6bfc052f6a 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -307,8 +307,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
         self.added_kv_proj_dim = added_kv_proj_dim
         self.added_proj_bias = added_proj_bias
 
-        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
-        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            from ..normalization import RMSNorm
+
+            self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        else:
+            self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
@@ -319,8 +325,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
             self.to_out.append(torch.nn.Dropout(dropout))
 
         if added_kv_proj_dim is not None:
-            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
-            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            if DIFFUSERS_ENABLE_HUB_KERNELS:
+                from ..normalization import RMSNorm
+
+                self.norm_added_q = RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            else:
+                self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
             self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
@@ -357,10 +369,11 @@ class FluxSingleTransformerBlock(nn.Module):
 
         self.norm = AdaLayerNormZeroSingle(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.act_mlp = nn.GELU(approximate="tanh")
-        else:
-            self.act_mlp = gelu_tanh_kernel()
+        self.act_mlp = nn.GELU(approximate="tanh")
+        # if not DIFFUSERS_ENABLE_HUB_KERNELS:
+        #     self.act_mlp = nn.GELU(approximate="tanh")
+        # else:
+        #     self.act_mlp = gelu_tanh_kernel()
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
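
Not part of the patch above: a minimal usage sketch of how the DIFFUSERS_ENABLE_HUB_KERNELS gating would be toggled. It assumes the flag is resolved from an environment variable of the same name when diffusers is imported; the checkpoint id, prompt, and pipeline call are illustrative only.

    import os

    # Assumption: set the flag before importing diffusers, since the modules touched by
    # this patch read DIFFUSERS_ENABLE_HUB_KERNELS at import/construction time to choose
    # between the stock nn.SiLU()/torch.nn.RMSNorm and the Hub kernel variants.
    os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "1"

    import torch
    from diffusers import FluxPipeline

    # Illustrative checkpoint; any Flux checkpoint routed through FluxAttention and the
    # AdaLayerNorm* classes patched above would exercise the same code paths.
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to("cuda")
    image = pipe("a photo of a cat").images[0]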