From 22ecd19f91039705f90a81c5cc1afa2d8413a26b Mon Sep 17 00:00:00 2001
From: Edna <88869424+Ednaordinary@users.noreply.github.com>
Date: Mon, 9 Jun 2025 21:32:52 -0600
Subject: [PATCH] take out variant stuff

---
 .../models/transformers/transformer_chroma.py | 119 ++++++------------
 1 file changed, 36 insertions(+), 83 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index c542bcaacc..1f726f5cb4 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -43,40 +43,27 @@ from ..embeddings import (
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import (
-    AdaLayerNormContinuous,
     AdaLayerNormContinuousPruned,
-    AdaLayerNormZero,
     AdaLayerNormZeroPruned,
-    AdaLayerNormZeroSingle,
     AdaLayerNormZeroSinglePruned,
 )
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-INVALID_VARIANT_ERRMSG = "`variant` must be `'flux' or `'chroma'`."
-
 
 @maybe_allow_in_graph
-class FluxSingleTransformerBlock(nn.Module):
+class ChromaSingleTransformerBlock(nn.Module):
     def __init__(
         self,
         dim: int,
         num_attention_heads: int,
         attention_head_dim: int,
         mlp_ratio: float = 4.0,
-        variant: str = "flux",
     ):
         super().__init__()
         self.mlp_hidden_dim = int(dim * mlp_ratio)
-
-        if variant == "flux":
-            self.norm = AdaLayerNormZeroSingle(dim)
-        elif variant == "chroma":
-            self.norm = AdaLayerNormZeroSinglePruned(dim)
-        else:
-            raise ValueError(INVALID_VARIANT_ERRMSG)
-
+        self.norm = AdaLayerNormZeroSinglePruned(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
@@ -132,7 +119,7 @@ class FluxSingleTransformerBlock(nn.Module):
 
 
 @maybe_allow_in_graph
-class FluxTransformerBlock(nn.Module):
+class ChromaTransformerBlock(nn.Module):
     def __init__(
         self,
         dim: int,
@@ -140,18 +127,10 @@ class FluxTransformerBlock(nn.Module):
         attention_head_dim: int,
         qk_norm: str = "rms_norm",
         eps: float = 1e-6,
-        variant: str = "flux",
     ):
         super().__init__()
-
-        if variant == "flux":
-            self.norm1 = AdaLayerNormZero(dim)
-            self.norm1_context = AdaLayerNormZero(dim)
-        elif variant == "chroma":
-            self.norm1 = AdaLayerNormZeroPruned(dim)
-            self.norm1_context = AdaLayerNormZeroPruned(dim)
-        else:
-            raise ValueError(INVALID_VARIANT_ERRMSG)
+        self.norm1 = AdaLayerNormZeroPruned(dim)
+        self.norm1_context = AdaLayerNormZeroPruned(dim)
 
         self.attn = Attention(
             query_dim=dim,
@@ -231,13 +210,13 @@ class FluxTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states
 
 
-class FluxTransformer2DModel(
+class ChromaTransformer2DModel(
     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
 ):
     """
-    The Transformer model introduced in Flux.
+    The Transformer model introduced in Flux, modified for Chroma.
 
-    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+    Reference: https://huggingface.co/lodestones/Chroma
 
     Args:
         patch_size (`int`, defaults to `1`):
@@ -266,7 +245,7 @@ class FluxTransformer2DModel(
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
+    _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
 
     @register_to_config
@@ -283,7 +262,7 @@ class FluxTransformer2DModel(
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
-        variant: str = "flux",
         approximator_in_factor: int = 16,
         approximator_hidden_dim: int = 5120,
         approximator_layers: int = 5,
     ):
         super().__init__()
@@ -294,31 +272,21 @@ class FluxTransformer2DModel(
 
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
 
-        if variant == "flux":
-            text_time_guidance_cls = (
-                CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
-            )
-            self.time_text_embed = text_time_guidance_cls(
-                embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
-            )
-        elif variant == "chroma":
-            self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
-                factor=approximator_in_factor,
-                hidden_dim=approximator_hidden_dim,
-                out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
-                embedding_dim=self.inner_dim,
-                n_layers=approximator_layers,
-            )
-            self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
-        else:
-            raise ValueError(INVALID_VARIANT_ERRMSG)
+        self.time_text_embed = CombinedTimestepTextProjChromaEmbeddings(
+            factor=approximator_in_factor,
+            hidden_dim=approximator_hidden_dim,
+            out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
+            embedding_dim=self.inner_dim,
+            n_layers=approximator_layers,
+        )
+        self.distilled_guidance_layer = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
 
         self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
         self.x_embedder = nn.Linear(in_channels, self.inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [
-                FluxTransformerBlock(
+                ChromaTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
@@ -330,7 +298,7 @@ class FluxTransformer2DModel(
 
         self.single_transformer_blocks = nn.ModuleList(
             [
-                FluxSingleTransformerBlock(
+                ChromaSingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
                     attention_head_dim=attention_head_dim,
@@ -340,16 +308,12 @@ class FluxTransformer2DModel(
             ]
         )
 
-        norm_out_cls = AdaLayerNormContinuous if variant != "chroma" else AdaLayerNormContinuousPruned
+        norm_out_cls = AdaLayerNormContinuousPruned
        self.norm_out = norm_out_cls(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
         self.gradient_checkpointing = False
 
-    @property
-    def is_chroma(self) -> bool:
-        return isinstance(self.time_text_embed, CombinedTimestepTextProjChromaEmbeddings)
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -506,22 +470,14 @@ class FluxTransformer2DModel(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
 
-        is_chroma = self.is_chroma
         hidden_states = self.x_embedder(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype) * 1000
         if guidance is not None:
             guidance = guidance.to(hidden_states.dtype) * 1000
 
-        if not is_chroma:
-            temb = (
-                self.time_text_embed(timestep, pooled_projections)
-                if guidance is None
-                else self.time_text_embed(timestep, guidance, pooled_projections)
-            )
-        else:
-            input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
-            pooled_temb = self.distilled_guidance_layer(input_vec)
+        input_vec = self.time_text_embed(timestep, guidance, pooled_projections)
+        pooled_temb = self.distilled_guidance_layer(input_vec)
 
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
@@ -547,18 +503,17 @@ class FluxTransformer2DModel(
             joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
 
         for index_block, block in enumerate(self.transformer_blocks):
-            if is_chroma:
-                img_offset = 3 * len(self.single_transformer_blocks)
-                txt_offset = img_offset + 6 * len(self.transformer_blocks)
-                img_modulation = img_offset + 6 * index_block
-                text_modulation = txt_offset + 6 * index_block
-                temb = torch.cat(
-                    (
-                        pooled_temb[:, img_modulation : img_modulation + 6],
-                        pooled_temb[:, text_modulation : text_modulation + 6],
-                    ),
-                    dim=1,
-                )
+            img_offset = 3 * len(self.single_transformer_blocks)
+            txt_offset = img_offset + 6 * len(self.transformer_blocks)
+            img_modulation = img_offset + 6 * index_block
+            text_modulation = txt_offset + 6 * index_block
+            temb = torch.cat(
+                (
+                    pooled_temb[:, img_modulation : img_modulation + 6],
+                    pooled_temb[:, text_modulation : text_modulation + 6],
+                ),
+                dim=1,
+            )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -591,9 +546,8 @@ class FluxTransformer2DModel(
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if is_chroma:
-                start_idx = 3 * index_block
-                temb = pooled_temb[:, start_idx : start_idx + 3]
+            start_idx = 3 * index_block
+            temb = pooled_temb[:, start_idx : start_idx + 3]
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = self._gradient_checkpointing_func(
                     block,
@@ -621,8 +575,7 @@ class FluxTransformer2DModel(
 
         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
 
-        if is_chroma:
-            temb = pooled_temb[:, -2:]
+        temb = pooled_temb[:, -2:]
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)
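
Reviewer sketch (not part of the patch): with the variant branches gone, every modulation vector comes from one `distilled_guidance_layer` output, laid out as 3 vectors per single block at the front, then 6 image-branch and 6 text-branch vectors per double block, then 2 trailing vectors for `norm_out`. The standalone snippet below only mirrors that slicing arithmetic; the block counts and vector width are illustrative placeholders, not the model's config defaults.

    import torch

    num_layers = 4          # double-stream blocks (placeholder, not the real default)
    num_single_layers = 8   # single-stream blocks (placeholder, not the real default)
    inner_dim = 32          # width of one modulation vector (placeholder)
    batch = 1

    # Same formula the patch passes to CombinedTimestepTextProjChromaEmbeddings.
    out_dim = 3 * num_single_layers + 2 * 6 * num_layers + 2
    pooled_temb = torch.randn(batch, out_dim, inner_dim)  # stand-in for the distilled_guidance_layer output

    img_offset = 3 * num_single_layers
    txt_offset = img_offset + 6 * num_layers

    # Double-stream blocks: 6 image-branch + 6 text-branch vectors each.
    for index_block in range(num_layers):
        img_modulation = img_offset + 6 * index_block
        text_modulation = txt_offset + 6 * index_block
        temb = torch.cat(
            (
                pooled_temb[:, img_modulation : img_modulation + 6],
                pooled_temb[:, text_modulation : text_modulation + 6],
            ),
            dim=1,
        )
        assert temb.shape == (batch, 12, inner_dim)

    # Single-stream blocks: 3 vectors each, taken from the front of pooled_temb.
    for index_block in range(num_single_layers):
        start_idx = 3 * index_block
        temb = pooled_temb[:, start_idx : start_idx + 3]
        assert temb.shape == (batch, 3, inner_dim)

    # The final two vectors feed AdaLayerNormContinuousPruned (norm_out).
    assert pooled_temb[:, -2:].shape == (batch, 2, inner_dim)

With Flux-scale counts (19 double and 38 single blocks, assuming those defaults carry over), the same formula gives out_dim = 3 * 38 + 2 * 6 * 19 + 2 = 344 modulation vectors.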