diff --git a/tests/models/transformers/test_models_transformer_chroma.py b/tests/models/transformers/test_models_transformer_chroma.py index 8ed7538aaf..a75e7fab47 100644 --- a/tests/models/transformers/test_models_transformer_chroma.py +++ b/tests/models/transformers/test_models_transformer_chroma.py @@ -125,7 +125,7 @@ class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase): "num_layers": 1, "num_single_layers": 1, "attention_head_dim": 16, - "num_attention_heads": 2, + "num_attention_heads": 192, "joint_attention_dim": 32, "axes_dims_rope": [4, 4, 8], }