1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Initial commit: Chroma as a FLUX.1 variant.

This commit is contained in:
Hameer Abbasi
2025-05-17 05:20:26 +05:00
parent 9836f0e000
commit 8ceed7d3ae
3 changed files with 274 additions and 26 deletions

View File

@@ -31,7 +31,7 @@ def get_timestep_embedding(
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
) -> torch.Tensor:
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -1327,7 +1327,7 @@ class Timesteps(nn.Module):
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
@@ -1637,6 +1637,50 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
return conditioning
class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
super().__init__()
self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
self.embedder = ChromaApproximator(
in_dim=factor * 4,
out_dim=out_dim,
hidden_dim=hidden_dim,
n_layers=n_layers,
)
self.embedding_dim = embedding_dim
self.register_buffer(
"mod_proj",
get_timestep_embedding(torch.arange(out_dim), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
persistent=False,
)
def forward(
self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
) -> torch.Tensor:
mod_index_length = self.mod_proj.shape[0]
timesteps_proj = self.time_proj(timestep) + self.time_proj(pooled_projections)
if guidance is not None:
guidance_proj = self.guidance_proj(guidance)
else:
guidance_proj = torch.zeros(
(self.embedding_dim, self.guidance_proj.num_channels),
dtype=timesteps_proj.dtype,
device=timesteps_proj.device,
)
mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
timestep_guidance = (
torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
)
input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
conditioning = self.embedder(input_vec)
return conditioning
class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()
@@ -2230,6 +2274,25 @@ class PixArtAlphaTextProjection(nn.Module):
return hidden_states
class ChromaApproximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
super().__init__()
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
self.layers = nn.ModuleList(
[PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
)
self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
self.out_proj = nn.Linear(hidden_dim, out_dim)
def forward(self, x):
x = self.in_proj(x)
for layer, norms in zip(self.layers, self.norms):
x = x + layer(norms(x))
return self.out_proj(x)
class IPAdapterPlusImageProjectionBlock(nn.Module):
def __init__(
self,

View File

@@ -171,6 +171,46 @@ class AdaLayerNormZero(nn.Module):
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZeroPruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
super().__init__()
if num_embeddings is not None:
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
else:
self.emb = None
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
scale_msa, shift_msa, gate_msa, scale_mlp, shift_mlp, gate_mlp = emb.chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZeroSingle(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
@@ -203,6 +243,35 @@ class AdaLayerNormZeroSingle(nn.Module):
return x, gate_msa
class AdaLayerNormZeroSinglePruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
super().__init__()
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
scale_msa, shift_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class LuminaRMSNormZero(nn.Module):
"""
Norm layer adaptive RMS normalization zero.
@@ -305,6 +374,50 @@ class AdaGroupNorm(nn.Module):
return x
class AdaLayerNormContinuousPruned(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
Args:
embedding_dim (`int`): Embedding dimension to use during projection.
conditioning_embedding_dim (`int`): Dimension of the input condition.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
eps (`float`, defaults to 1e-5): Epsilon factor.
bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
norm_type (`str`, defaults to `"layer_norm"`):
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
"""
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
):
super().__init__()
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
shift, scale = torch.chunk(emb.to(x.dtype), 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class AdaLayerNormContinuous(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).

View File

@@ -33,22 +33,49 @@ from ..attention_processor import (
FusedFluxAttnProcessor2_0,
)
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjChromaEmbeddings,
CombinedTimestepTextProjEmbeddings,
FluxPosEmbed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from ..normalization import (
AdaLayerNormContinuous,
AdaLayerNormContinuousPruned,
AdaLayerNormZero,
AdaLayerNormZeroPruned,
AdaLayerNormZeroSingle,
AdaLayerNormZeroSinglePruned,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
INVALID_VARIANT_ERRMSG = "`variant` must be `'flux' or `'chroma'`."
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
variant: str = "flux",
):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
if variant == "flux":
self.norm = AdaLayerNormZeroSingle(dim)
elif variant == "chroma":
self.norm = AdaLayerNormZeroSinglePruned(dim)
else:
raise ValueError(INVALID_VARIANT_ERRMSG)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
@@ -106,12 +133,24 @@ class FluxSingleTransformerBlock(nn.Module):
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
def __init__(
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
variant: str = "flux",
):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
if variant == "flux":
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
elif variant == "chroma":
self.norm1 = AdaLayerNormZeroPruned(dim)
self.norm1_context = AdaLayerNormZeroPruned(dim)
else:
raise ValueError(INVALID_VARIANT_ERRMSG)
self.attn = Attention(
query_dim=dim,
@@ -141,10 +180,11 @@ class FluxTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
encoder_hidden_states, emb=temb_txt
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
@@ -241,7 +281,11 @@ class FluxTransformer2DModel(
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int] = (16, 56, 56),
axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
variant: str = "flux",
approximator_in_factor: int = 16,
approximator_hidden_dim: int = 5120,
approximator_layers: int = 5,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -249,12 +293,23 @@ class FluxTransformer2DModel(
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
if variant == "flux":
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
elif variant == "chroma":
self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
factor=approximator_in_factor,
hidden_dim=approximator_hidden_dim,
out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
embedding_dim=self.inner_dim,
n_layers=approximator_layers,
)
else:
raise ValueError(INVALID_VARIANT_ERRMSG)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
@@ -265,6 +320,7 @@ class FluxTransformer2DModel(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
variant=variant,
)
for _ in range(num_layers)
]
@@ -276,12 +332,14 @@ class FluxTransformer2DModel(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
variant=variant,
)
for _ in range(num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
norm_out_cls = AdaLayerNormContinuous if variant != "chroma" else AdaLayerNormContinuousPruned
self.norm_out = norm_out_cls(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
@@ -442,19 +500,22 @@ class FluxTransformer2DModel(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
is_chroma = isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
if not is_chroma:
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
else:
pooled_temb = self.time_text_embed(timestep, guidance, pooled_projections)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
@@ -479,6 +540,12 @@ class FluxTransformer2DModel(
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
for index_block, block in enumerate(self.transformer_blocks):
if is_chroma:
start_idx1 = 3 * len(self.single_transformer_blocks) + 6 * index_block
start_idx2 = start_idx1 + 6 * len(self.transformer_blocks)
temb = torch.cat(
(pooled_temb[:, start_idx1 : start_idx1 + 6], pooled_temb[:, start_idx2 : start_idx2 + 6]), dim=1
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
@@ -511,6 +578,9 @@ class FluxTransformer2DModel(
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if is_chroma:
start_idx = 3 * index_block
temb = pooled_temb[:, start_idx : start_idx + 3]
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
@@ -538,6 +608,8 @@ class FluxTransformer2DModel(
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
if is_chroma:
temb = pooled_temb[:, -2:]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)