remove silu for CogView4 (#12150)
* CogView4: remove SiLU in final AdaLN (match Megatron); add switch to AdaLayerNormContinuous; split temb_raw/temb_blocks
* CogView4: use local final AdaLN (no SiLU) per review; keep generic AdaLN unchanged
* re-add configs as normal files (no LFS)
* Apply suggestions from code review
* Apply style fixes

---------

Co-authored-by: 武嘉涵 <lambert@wujiahandeMacBook-Pro.local>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
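To make the behavioral change concrete, the sketch below is a minimal, self-contained illustration, not the library code; the helper names and dimensions are purely illustrative. The generic AdaLayerNormContinuous in diffusers applies a SiLU to the conditioning embedding before the linear projection, while the CogView4-specific head added in this commit projects the conditioning embedding directly, as in the Megatron reference.

import torch
import torch.nn as nn
import torch.nn.functional as F

def generic_adaln(x, temb, linear, norm):
    # generic diffusers-style final AdaLN: SiLU on the conditioning embedding first
    scale, shift = linear(F.silu(temb)).chunk(2, dim=1)
    return norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]

def cogview4_adaln(x, temb, linear, norm):
    # CogView4 head after this commit: no activation on the conditioning embedding
    scale, shift = linear(temb).chunk(2, dim=1)
    return norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]

dim, cond_dim = 64, 128
linear = nn.Linear(cond_dim, dim * 2)
norm = nn.LayerNorm(dim, elementwise_affine=False)
x, temb = torch.randn(2, 16, dim), torch.randn(2, cond_dim)

# Same parameters, different result: the only difference is the extra SiLU on temb.
print(torch.allclose(generic_adaln(x, temb, linear, norm),
                     cogview4_adaln(x, temb, linear, norm)))  # expected: False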
@@ -28,7 +28,7 @@ from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
+from ..normalization import LayerNorm, RMSNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -584,6 +584,38 @@ class CogView4RotaryPosEmbed(nn.Module):
         return (freqs.cos(), freqs.sin())
 
 
+class CogView4AdaLayerNormContinuous(nn.Module):
+    """
+    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before
+    the Linear on the conditioning embedding.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # NOTE: no SiLU on the conditioning embedding here (matches Megatron).
+        emb = self.linear(conditioning_embedding.to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:
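For a quick sanity check of the new head, a minimal smoke test is sketched below. The import path and the shapes are assumptions (the class is defined in the CogView4 transformer module touched by this commit); the test only verifies that the module runs and preserves the hidden-state shape.

import torch

# Assumed import path for the module added in this commit; adjust to your checkout.
from diffusers.models.transformers.transformer_cogview4 import CogView4AdaLayerNormContinuous

norm_out = CogView4AdaLayerNormContinuous(
    embedding_dim=64, conditioning_embedding_dim=128, elementwise_affine=False
)
x = torch.randn(2, 16, 64)      # (batch, tokens, inner_dim)
temb = torch.randn(2, 128)      # conditioning / time embedding
print(norm_out(x, temb).shape)  # torch.Size([2, 16, 64])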
@@ -666,7 +698,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
         )
 
         # 4. Output projection
-        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+        self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
 
         self.gradient_checkpointing = False
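One practical note, inferred from the diff rather than stated in the commit: the generic head's SiLU has no parameters and the submodule names (linear, norm) are unchanged, so existing norm_out weights should load into the new head as-is; only the forward computation changes. A hypothetical check (import paths assumed as above):

from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_cogview4 import CogView4AdaLayerNormContinuous

old = AdaLayerNormContinuous(64, 128, elementwise_affine=False)
new = CogView4AdaLayerNormContinuous(64, 128, elementwise_affine=False)
print(sorted(old.state_dict()) == sorted(new.state_dict()))  # expected: True
new.load_state_dict(old.state_dict())  # should report no missing/unexpected keys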