diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md index 0f8c9940f2..22448d88e0 100644 --- a/docs/source/en/api/pipelines/chroma.md +++ b/docs/source/en/api/pipelines/chroma.md @@ -25,6 +25,7 @@ Original model checkpoints for Chroma can be found [here](https://huggingface.co Chroma can use all the same optimizations as Flux. + ## Inference (Single File) diff --git a/tests/models/transformers/test_models_transformer_chroma.py b/tests/models/transformers/test_models_transformer_chroma.py index 5e177cca44..d1a061ce10 100644 --- a/tests/models/transformers/test_models_transformer_chroma.py +++ b/tests/models/transformers/test_models_transformer_chroma.py @@ -125,7 +125,7 @@ class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase): "num_layers": 1, "num_single_layers": 1, "attention_head_dim": 16, - "num_attention_heads": 192, + "num_attention_heads": 2, "joint_attention_dim": 32, "axes_dims_rope": [4, 4, 8], "approximator_num_channels": 8, diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py index c47719d3e4..e8c2944a9c 100644 --- a/tests/pipelines/chroma/test_pipeline_chroma.py +++ b/tests/pipelines/chroma/test_pipeline_chroma.py @@ -39,14 +39,13 @@ class ChromaPipelineFastTests( in_channels=4, num_layers=num_layers, num_single_layers=num_single_layers, - attention_head_dim=4, - num_attention_heads=4, + attention_head_dim=16, + num_attention_heads=2, joint_attention_dim=32, axes_dims_rope=[4, 4, 8], - approximator_in_factor=1, - approximator_hidden_dim=32, - approximator_out_dim=64, - approximator_layers=5, + approximator_num_channels=8, + approximator_hidden_dim=16, + approximator_layers=1, ) torch.manual_seed(0)