Move chroma layers into transformer
@@ -40,16 +40,123 @@ from ..embeddings import (
 )
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import (
-    AdaLayerNormContinuousPruned,
-    AdaLayerNormZeroPruned,
-    AdaLayerNormZeroSinglePruned,
-)
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class ChromaAdaLayerNormZeroPruned(nn.Module):
+    r"""
+    Adaptive layer norm zero (adaLN-Zero) norm layer.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+        super().__init__()
+        if num_embeddings is not None:
+            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+        else:
+            self.emb = None
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        elif norm_type == "fp32_layer_norm":
+            self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+        else:
+            raise ValueError(
+                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timestep: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+        hidden_dtype: Optional[torch.dtype] = None,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.emb is not None:
+            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class ChromaAdaLayerNormZeroSinglePruned(nn.Module):
+    r"""
+    Adaptive layer norm zero (adaLN-Zero) norm layer.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+        super().__init__()
+
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        shift_msa, scale_msa, gate_msa = emb.squeeze(0).chunk(3, dim=0)
+        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        return x, gate_msa
+
+
+class ChromaAdaLayerNormContinuousPruned(nn.Module):
+    r"""
+    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+    Args:
+        embedding_dim (`int`): Embedding dimension to use during projection.
+        conditioning_embedding_dim (`int`): Dimension of the input condition.
+        elementwise_affine (`bool`, defaults to `True`):
+            Boolean flag to denote if affine transformation should be applied.
+        eps (`float`, defaults to 1e-5): Epsilon factor.
+        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+        # However, this is how it was implemented in the original code, and it's rather likely you should
+        # set `elementwise_affine` to False.
+        elementwise_affine=True,
+        eps=1e-5,
+        bias=True,
+        norm_type="layer_norm",
+    ):
+        super().__init__()
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
+        shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
 @maybe_allow_in_graph
 class ChromaSingleTransformerBlock(nn.Module):
     def __init__(
@@ -61,7 +168,7 @@ class ChromaSingleTransformerBlock(nn.Module):
     ):
         super().__init__()
         self.mlp_hidden_dim = int(dim * mlp_ratio)
-        self.norm = AdaLayerNormZeroSinglePruned(dim)
+        self.norm = ChromaAdaLayerNormZeroSinglePruned(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
@@ -127,8 +234,8 @@ class ChromaTransformerBlock(nn.Module):
         eps: float = 1e-6,
     ):
         super().__init__()
-        self.norm1 = AdaLayerNormZeroPruned(dim)
-        self.norm1_context = AdaLayerNormZeroPruned(dim)
+        self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
+        self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
 
         self.attn = Attention(
             query_dim=dim,
@@ -298,8 +405,7 @@ class ChromaTransformer2DModel(
             ]
         )
 
-        norm_out_cls = AdaLayerNormContinuousPruned
-        self.norm_out = norm_out_cls(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+        self.norm_out = ChromaAdaLayerNormContinuousPruned(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
         self.gradient_checkpointing = False
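
For orientation, here is a minimal, self-contained sketch (not part of the commit) of the modulation math these pruned layers implement. Unlike the non-pruned adaLN-Zero layers, which appear to project the conditioning with an internal silu + Linear, the pruned variants expect an already-projected embedding `emb`, slice it into shift/scale/gate chunks, and apply them around a parameter-free LayerNorm. The `dim` value and tensor shapes below are illustrative assumptions, not taken from the diff.

import torch
import torch.nn as nn

# Sketch of the pruned adaLN-Zero forward pass (ChromaAdaLayerNormZeroPruned.forward),
# using hypothetical shapes. The conditioning embedding is precomputed elsewhere in
# the model and only sliced and applied here.
dim = 64
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)  # no learnable affine

x = torch.randn(2, 16, dim)   # (batch, tokens, dim) -- illustrative
emb = torch.randn(1, 6, dim)  # packed shift/scale/gate vectors for the attn and MLP branches

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
x_mod = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
print(x_mod.shape)  # torch.Size([2, 16, 64]); gate_msa and the *_mlp chunks modulate later sub-layers

The `Single` and `Continuous` variants in the diff follow the same pattern, splitting `emb` into 3 and 2 chunks respectively.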