From 602af7411e1a095e699f4fa7ec794f1a91094348 Mon Sep 17 00:00:00 2001
From: DN6
Date: Mon, 16 Jun 2025 23:38:17 +0530
Subject: [PATCH] update

---
 .../pipelines/chroma/pipeline_chroma.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 9048efec8d..c1b543656b 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -549,6 +549,8 @@ class ChromaPipeline(
         self, batch_size, sequence_length, prompt_attention_mask, negative_prompt_attention_mask=None
     ):
         device = prompt_attention_mask.device
+
+        # Extend the prompt attention mask to account for image tokens in the final sequence
         attention_mask = torch.cat(
             [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)], dim=1
         )
@@ -764,13 +766,6 @@ class ChromaPipeline(
             latents,
         )
 
-        attention_mask, negative_attention_mask = self._prepare_attention_mask(
-            latents.shape[0],
-            latents.shape[1],
-            prompt_attention_mask,
-            negative_prompt_attention_mask,
-        )
-
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = latents.shape[1]
@@ -781,6 +776,14 @@ class ChromaPipeline(
             self.scheduler.config.get("base_shift", 0.5),
             self.scheduler.config.get("max_shift", 1.15),
         )
+
+        attention_mask, negative_attention_mask = self._prepare_attention_mask(
+            batch_size=latents.shape[0],
+            sequence_length=image_seq_len,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+        )
+
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
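
Note (not part of the patch): a minimal sketch of the mask extension that the relocated _prepare_attention_mask call relies on, i.e. padding the text-token mask with ones for the image tokens. The batch size, sequence lengths, and token count below are illustrative assumptions, not values taken from the pipeline.

# Illustrative sketch only -- assumed shapes, not ChromaPipeline internals.
import torch

batch_size = 2
text_seq_len = 512    # assumed text-encoder sequence length
image_seq_len = 1024  # assumed number of packed image/latent tokens

# Prompt mask: 1 for real text tokens, 0 for padding (77 real tokens assumed).
prompt_attention_mask = torch.zeros(batch_size, text_seq_len)
prompt_attention_mask[:, :77] = 1.0

# Image tokens are always attended to, so the mask is extended with ones,
# mirroring the torch.cat in _prepare_attention_mask.
attention_mask = torch.cat(
    [prompt_attention_mask, torch.ones(batch_size, image_seq_len)], dim=1
)
print(attention_mask.shape)  # torch.Size([2, 1536])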