diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index c3de16a363..02eee74423 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -418,6 +418,7 @@ class ChromaTransformer2DModel(
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 4096,
+        pooled_projection_dim: int = 768,
         axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
         approximator_num_channels: int = 64,
         approximator_hidden_dim: int = 5120,
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 9cb7a19081..b6c398ead0 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -219,25 +219,23 @@ class ChromaPipeline(
 
         text_inputs = self.tokenizer(
             prompt,
-            padding=False,
+            padding=True,
             max_length=max_sequence_length,
             truncation=True,
             return_length=False,
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
-        pad_token_id = self.tokenizer.pad_token_id
-        text_input_ids = torch.cat(
-            [
-                text_inputs.input_ids,
-                torch.full((text_inputs.input_ids.size(0), 1), pad_token_id, dtype=text_inputs.input_ids.dtype),
-            ],
-            dim=1,
-        )
+        text_input_ids = text_inputs.input_ids
+        attention_mask = text_inputs.attention_mask.clone()
+
+        # Chroma requires the attention mask to include one padding token
+        seq_lengths = attention_mask.sum(dim=1)
+        mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
+        attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
 
         prompt_embeds = self.text_encoder(
-            text_input_ids.to(device),
-            output_hidden_states=False,
+            text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
         )[0]
 
         dtype = self.text_encoder.dtype
diff --git a/tests/pipelines/chroma/chroma.py b/tests/pipelines/chroma/chroma.py
index 6f3e0ea807..5025769c9a 100644
--- a/tests/pipelines/chroma/chroma.py
+++ b/tests/pipelines/chroma/chroma.py
@@ -19,7 +19,6 @@ from ..test_pipelines_common import (
     FasterCacheTesterMixin,
     FluxIPAdapterTesterMixin,
     PipelineTesterMixin,
-    PyramidAttentionBroadcastTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
     check_qkv_fusion_processors_exist,
 )
@@ -29,11 +28,10 @@ class ChromaPipelineFastTests(
     unittest.TestCase,
     PipelineTesterMixin,
     FluxIPAdapterTesterMixin,
-    PyramidAttentionBroadcastTesterMixin,
     FasterCacheTesterMixin,
 ):
     pipeline_class = ChromaPipeline
-    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
+    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
     batch_params = frozenset(["prompt"])
 
     # there is no xformers processor for Flux
@@ -46,7 +44,7 @@ class ChromaPipelineFastTests(
         spatial_attention_timestep_skip_range=(-1, 901),
         unconditional_batch_skip_range=2,
         attention_weight_callback=lambda _: 0.5,
-        is_guidance_distilled=True,
+        is_guidance_distilled=False,
     )
 
     def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
@@ -182,3 +180,9 @@ class ChromaPipelineFastTests(
         image = pipe(**inputs).images[0]
         output_height, output_width, _ = image.shape
         assert (output_height, output_width) == (expected_height, expected_width)
+
+    @unittest.skip(
+        "Chroma uses Flux encode_prompt but uses CFG. This test is incompatible with the pipeline since the test case does not use negative prompt embeds."
+    )
+    def test_encode_prompt_works_in_isolation(self):
+        pass
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 91ffc0ae53..22c17085a2 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -521,7 +521,8 @@ class FluxIPAdapterTesterMixin:
 
     def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
         inputs["negative_prompt"] = ""
-        inputs["true_cfg_scale"] = 4.0
+        if "true_cfg_scale" in inspect.signature(self.pipeline_class.__call__).parameters:
+            inputs["true_cfg_scale"] = 4.0
         inputs["output_type"] = "np"
         inputs["return_dict"] = False
         return inputs
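
For reference, a minimal standalone sketch of the attention-mask handling introduced in the pipeline_chroma.py hunk above: the tokenizer now right-pads the batch, and the mask is then widened so that exactly one padding token per sequence stays visible to the text encoder. The tensor values below are invented for illustration; only torch is required.

# Illustrative only: mirrors the mask-widening logic from the pipeline_chroma.py change above.
import torch

# Hypothetical right-padded batch: 1 = real token, 0 = padding (values made up).
attention_mask = torch.tensor(
    [
        [1, 1, 1, 0, 0, 0],  # 3 real tokens
        [1, 1, 1, 1, 1, 0],  # 5 real tokens
    ]
)
batch_size = attention_mask.size(0)

seq_lengths = attention_mask.sum(dim=1)  # real tokens per sequence: tensor([3, 5])
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
# Using `<=` rather than `<` is what keeps one extra position, i.e. the first pad token, unmasked.
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()

print(attention_mask)
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1]])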