From 178c4ec928254b8e7f2bac7363ffa877e48aba00 Mon Sep 17 00:00:00 2001
From: BuildTools
Date: Fri, 13 Jun 2025 07:46:29 -0600
Subject: [PATCH 1/5] push local changes, fix docs

---
 docs/source/en/api/pipelines/chroma.md             |  1 +
 .../transformers/test_models_transformer_chroma.py |  2 +-
 tests/pipelines/chroma/test_pipeline_chroma.py     | 11 +++++------
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md
index 0f8c9940f2..22448d88e0 100644
--- a/docs/source/en/api/pipelines/chroma.md
+++ b/docs/source/en/api/pipelines/chroma.md
@@ -25,6 +25,7 @@ Original model checkpoints for Chroma can be found [here](https://huggingface.co
 
 Chroma can use all the same optimizations as Flux.
 
+
 ## Inference (Single File)
 
diff --git a/tests/models/transformers/test_models_transformer_chroma.py b/tests/models/transformers/test_models_transformer_chroma.py
index 5e177cca44..d1a061ce10 100644
--- a/tests/models/transformers/test_models_transformer_chroma.py
+++ b/tests/models/transformers/test_models_transformer_chroma.py
@@ -125,7 +125,7 @@ class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
             "num_layers": 1,
             "num_single_layers": 1,
             "attention_head_dim": 16,
-            "num_attention_heads": 192,
+            "num_attention_heads": 2,
             "joint_attention_dim": 32,
             "axes_dims_rope": [4, 4, 8],
             "approximator_num_channels": 8,
diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py
index c47719d3e4..e8c2944a9c 100644
--- a/tests/pipelines/chroma/test_pipeline_chroma.py
+++ b/tests/pipelines/chroma/test_pipeline_chroma.py
@@ -39,14 +39,13 @@ class ChromaPipelineFastTests(
             in_channels=4,
             num_layers=num_layers,
             num_single_layers=num_single_layers,
-            attention_head_dim=4,
-            num_attention_heads=4,
+            attention_head_dim=16,
+            num_attention_heads=2,
             joint_attention_dim=32,
             axes_dims_rope=[4, 4, 8],
-            approximator_in_factor=1,
-            approximator_hidden_dim=32,
-            approximator_out_dim=64,
-            approximator_layers=5,
+            approximator_num_channels=8,
+            approximator_hidden_dim=16,
+            approximator_layers=1,
         )
 
         torch.manual_seed(0)

From 16b6e33916e5e39916ec2a99b77a4d4b3e43c7a7 Mon Sep 17 00:00:00 2001
From: BuildTools
Date: Fri, 13 Jun 2025 08:11:12 -0600
Subject: [PATCH 2/5] add encoder test, remove pooled dim

---
 src/diffusers/models/transformers/transformer_chroma.py | 1 -
 tests/pipelines/chroma/test_pipeline_chroma.py           | 6 ------
 2 files changed, 7 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index 1ca6cf02fa..2b415cfed2 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -409,7 +409,6 @@ class ChromaTransformer2DModel(
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 4096,
-        pooled_projection_dim: int = 768,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
         approximator_num_channels: int = 64,
         approximator_hidden_dim: int = 5120,
diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py
index e8c2944a9c..1f20e2081d 100644
--- a/tests/pipelines/chroma/test_pipeline_chroma.py
+++ b/tests/pipelines/chroma/test_pipeline_chroma.py
@@ -168,9 +168,3 @@ class ChromaPipelineFastTests(
         image = pipe(**inputs).images[0]
         output_height, output_width, _ = image.shape
         assert (output_height, output_width) == (expected_height, expected_width)
-
-    @unittest.skip(
-        "Chroma uses Flux encode_prompt but uses CFG. This test is incompatible with the pipeline since the test case does not use a negative prompt embeds"
-    )
-    def test_encode_prompt_works_in_isolation(self):
-        pass

From 06fb9957a7c7a4d6238d931740b7d37735d41d6b Mon Sep 17 00:00:00 2001
From: BuildTools
Date: Fri, 13 Jun 2025 08:38:02 -0600
Subject: [PATCH 3/5] default proj dim

---
 tests/pipelines/test_pipelines_common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 22c17085a2..978001e8c3 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -544,7 +544,7 @@ class FluxIPAdapterTesterMixin:
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        image_embed_dim = pipe.transformer.config.pooled_projection_dim
+        image_embed_dim = pipe.transformer.config.pooled_projection_dim or 768
 
         # forward pass without ip adapter
         inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))

From 49a4c8bc2209266bb867a83c3c1fd134cc7c8c9d Mon Sep 17 00:00:00 2001
From: BuildTools
Date: Fri, 13 Jun 2025 09:41:44 -0600
Subject: [PATCH 4/5] fix tests

---
 src/diffusers/pipelines/chroma/pipeline_chroma.py          | 7 +++++++
 .../models/transformers/test_models_transformer_chroma.py  | 2 +-
 tests/models/transformers/test_models_transformer_flux.py  | 4 +++-
 tests/pipelines/test_pipelines_common.py                   | 6 +++++-
 4 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 8289ce7872..b93b0d328c 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -668,6 +668,13 @@ class ChromaPipeline(
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]
 
+        if negative_prompt is not None and isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt] * batch_size
+        elif negative_prompt is not None and isinstance(negative_prompt, list):
+            if len(negative_prompt) == 1:
+                negative_prompt = negative_prompt * batch_size
+            else:
+                raise ValueError("prompt and negative_prompt are lists of unequal size")
diff --git a/tests/models/transformers/test_models_transformer_chroma.py b/tests/models/transformers/test_models_transformer_chroma.py
index d1a061ce10..93df7ca35c 100644
--- a/tests/models/transformers/test_models_transformer_chroma.py
+++ b/tests/models/transformers/test_models_transformer_chroma.py
@@ -82,7 +82,7 @@ class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = ChromaTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.
-    model_split_percents = [0.7, 0.6, 0.6]
+    model_split_percents = [0.8, 0.7, 0.7]
 
     # Skip setting testing with default: AttnProcessor
     uses_custom_attn_processor = True
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 33c8765358..036ed2ea30 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -57,7 +57,9 @@ def create_flux_ip_adapter_state_dict(model):
 
     image_projection = ImageProjection(
         cross_attention_dim=model.config["joint_attention_dim"],
-        image_embed_dim=model.config["pooled_projection_dim"],
+        image_embed_dim=(
+            model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
+        ),
         num_image_text_embeds=4,
     )
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 978001e8c3..1ced3b5ace 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -544,7 +544,11 @@ class FluxIPAdapterTesterMixin:
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         pipe.set_progress_bar_config(disable=None)
-        image_embed_dim = pipe.transformer.config.pooled_projection_dim or 768
+        image_embed_dim = (
+            pipe.transformer.config.pooled_projection_dim
+            if hasattr(pipe.transformer.config, "pooled_projection_dim")
+            else 768
+        )
 
         # forward pass without ip adapter
         inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))

From 3fe4ad67d58d83715bc238f8654f5e90bfc5653c Mon Sep 17 00:00:00 2001
From: BuildTools
Date: Fri, 13 Jun 2025 10:51:31 -0600
Subject: [PATCH 5/5] fix equal size list input

---
 src/diffusers/pipelines/chroma/pipeline_chroma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index b93b0d328c..46e95348a1 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -673,7 +673,7 @@ class ChromaPipeline(
         if negative_prompt is not None and isinstance(negative_prompt, str):
             negative_prompt = [negative_prompt] * batch_size
         elif negative_prompt is not None and isinstance(negative_prompt, list):
             if len(negative_prompt) == 1:
                 negative_prompt = negative_prompt * batch_size
-            else:
+            elif len(prompt) != len(negative_prompt):
                 raise ValueError("prompt and negative_prompt are lists of unequal size")
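
Note on the negative-prompt handling that patches 4 and 5 converge on: it is easiest to check in isolation. The sketch below is illustrative only, not the diffusers API; `normalize_negative_prompt` is a hypothetical standalone helper that mirrors the branching added to `ChromaPipeline.encode_prompt` in this series.

```python
from typing import List, Optional, Union


def normalize_negative_prompt(
    prompt: Union[str, List[str]],
    negative_prompt: Optional[Union[str, List[str]]],
) -> Optional[List[str]]:
    # Batch size follows the positive prompt, as in encode_prompt.
    batch_size = 1 if isinstance(prompt, str) else len(prompt)

    if negative_prompt is None:
        return None
    if isinstance(negative_prompt, str):
        # A single string is broadcast across the whole batch.
        return [negative_prompt] * batch_size
    if len(negative_prompt) == 1:
        # A one-element list is repeated, not nested into a list of lists.
        return negative_prompt * batch_size
    if len(negative_prompt) != batch_size:
        raise ValueError("prompt and negative_prompt are lists of unequal size")
    # Equal-size lists pass through unchanged (the patch 5 behavior).
    return negative_prompt


assert normalize_negative_prompt(["a", "b"], "blurry") == ["blurry", "blurry"]
assert normalize_negative_prompt(["a", "b"], ["blurry"]) == ["blurry", "blurry"]
assert normalize_negative_prompt(["a", "b"], ["x", "y"]) == ["x", "y"]
```

Broadcasting a single string or one-element list keeps the CFG negative batch aligned with the positive prompts without callers pre-expanding lists; only genuinely mismatched list lengths raise.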