1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

remove silu for CogView4 (#12150)

* CogView4: remove SiLU in final AdaLN (match Megatron); add  switch to AdaLayerNormContinuous; split temb_raw/temb_blocks

* CogView4: remove SiLU in final AdaLN (match Megatron); add  switch to AdaLayerNormContinuous; split temb_raw/temb_blocks

* CogView4: remove SiLU in final AdaLN (match Megatron); add  switch to AdaLayerNormContinuous; split temb_raw/temb_blocks

* CogView4: use local final AdaLN (no SiLU) per review; keep generic AdaLN unchanged

* re-add configs as normal files (no LFS)

* Apply suggestions from code review

* Apply style fixes

---------

Co-authored-by: 武嘉涵 <lambert@wujiahandeMacBook-Pro.local>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Lambert
2025-08-18 10:32:01 +08:00
committed by GitHub
parent e682af2027
commit 76c809e2ef

View File

@@ -28,7 +28,7 @@ from ..cache_utils import CacheMixin
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
from ..normalization import LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -584,6 +584,38 @@ class CogView4RotaryPosEmbed(nn.Module):
return (freqs.cos(), freqs.sin())
class CogView4AdaLayerNormContinuous(nn.Module):
"""
CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
Linear on conditioning embedding.
"""
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
norm_type: str = "layer_norm",
):
super().__init__()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
# *** NO SiLU here ***
emb = self.linear(conditioning_embedding.to(x.dtype))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
r"""
Args:
@@ -666,7 +698,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
)
# 4. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
self.gradient_checkpointing = False