From 1999bffda8372893ce0e3febcb85aa68b1e019ee Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Tue, 17 Jun 2025 09:56:53 +0200
Subject: [PATCH] update

---
 .../pipelines/chroma/pipeline_chroma.py | 45 +++++++++----------
 1 file changed, 22 insertions(+), 23 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index a86138f569..b73d17a9d2 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -549,29 +549,23 @@ class ChromaPipeline(
         return latents, latent_image_ids
 
     def _prepare_attention_mask(
-        self, batch_size, sequence_length, dtype, prompt_attention_mask=None, negative_prompt_attention_mask=None
+        self,
+        batch_size,
+        sequence_length,
+        dtype,
+        attention_mask=None,
     ):
-        attention_mask = None
-        if prompt_attention_mask is not None:
-            # Extend the prompt attention mask to account for image tokens in the final sequence
-            attention_mask = torch.cat(
-                [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=prompt_attention_mask.device)],
-                dim=1,
-            )
-            attention_mask = attention_mask.to(dtype)
+        if attention_mask is None:
+            return attention_mask
 
-        negative_attention_mask = None
-        if negative_prompt_attention_mask is not None:
-            negative_attention_mask = torch.cat(
-                [
-                    negative_prompt_attention_mask,
-                    torch.ones(batch_size, sequence_length, device=negative_prompt_attention_mask.device),
-                ],
-                dim=1,
-            )
-            negative_attention_mask = negative_attention_mask.to(dtype)
+        # Extend the prompt attention mask to account for image tokens in the final sequence
+        attention_mask = torch.cat(
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            dim=1,
+        )
+        attention_mask = attention_mask.to(dtype)
 
-        return attention_mask, negative_attention_mask
+        return attention_mask
 
     @property
     def guidance_scale(self):
@@ -787,12 +781,17 @@
             self.scheduler.config.get("max_shift", 1.15),
         )
 
-        attention_mask, negative_attention_mask = self._prepare_attention_mask(
+        attention_mask = self._prepare_attention_mask(
             batch_size=latents.shape[0],
             sequence_length=image_seq_len,
             dtype=latents.dtype,
-            prompt_attention_mask=prompt_attention_mask,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            attention_mask=prompt_attention_mask,
+        )
+        negative_attention_mask = self._prepare_attention_mask(
+            batch_size=latents.shape[0],
+            sequence_length=image_seq_len,
+            dtype=latents.dtype,
+            attention_mask=negative_prompt_attention_mask,
         )
 
         timesteps, num_inference_steps = retrieve_timesteps(
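
A minimal, self-contained sketch of the behavior the refactored helper ends up
with: the text-token mask is extended with ones for the image tokens appended
after it, then cast to the working dtype, and the helper is called once per
mask (prompt, then negative prompt) instead of handling both in one call. The
standalone function name, tensor values, and shapes below are illustrative and
not taken from the pipeline.

    import torch

    def prepare_attention_mask(batch_size, sequence_length, dtype, attention_mask=None):
        # No text mask supplied -> attend everywhere; pass None straight through.
        if attention_mask is None:
            return attention_mask
        # Image tokens are appended after the text tokens and must never be
        # masked out, so pad the text mask with ones along the sequence dim.
        attention_mask = torch.cat(
            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
            dim=1,
        )
        return attention_mask.to(dtype)

    # Illustrative shapes: 2 prompts, 4 text tokens (one padded out), 6 image tokens.
    text_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
    prompt_mask = prepare_attention_mask(2, 6, torch.float32, attention_mask=text_mask)
    negative_mask = prepare_attention_mask(2, 6, torch.float32, attention_mask=None)
    print(prompt_mask.shape)  # torch.Size([2, 10])
    print(negative_mask)      # None -> the caller skips masking for this branch

Collapsing the two branches into a single mask argument is what lets the call
site in the second hunk invoke the helper once for each of the two masks.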