diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index 597e1d34ee..62f739fc6b 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -416,7 +416,7 @@ class ChromaTransformer2DModel(
         num_attention_heads: int = 24,
         joint_attention_dim: int = 4096,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
-        approximator_in_factor: int = 16,
+        approximator_num_channels: int = 64,
         approximator_hidden_dim: int = 5120,
         approximator_layers: int = 5,
     ):
@@ -427,11 +427,11 @@ class ChromaTransformer2DModel(
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
 
         self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
-            num_channels=approximator_in_factor,
+            num_channels=approximator_num_channels // 4,
             out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
         )
         self.distilled_guidance_layer = ChromaApproximator(
-            in_dim=64,
+            in_dim=approximator_num_channels,
             out_dim=self.inner_dim,
             hidden_dim=approximator_hidden_dim,
             n_layers=approximator_layers,
diff --git a/tests/models/transformers/test_models_transformer_chroma.py b/tests/models/transformers/test_models_transformer_chroma.py
index 37be388d03..d1a061ce10 100644
--- a/tests/models/transformers/test_models_transformer_chroma.py
+++ b/tests/models/transformers/test_models_transformer_chroma.py
@@ -128,7 +128,7 @@ class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
             "num_attention_heads": 2,
             "joint_attention_dim": 32,
             "axes_dims_rope": [4, 4, 8],
-            "approximator_in_factor": 32,
+            "approximator_num_channels": 8,
             "approximator_hidden_dim": 16,
             "approximator_layers": 1,
         }