From 544dad4c2567cbd9d2d908f3e14b2df6f591b309 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 17 Jun 2025 05:54:38 +0200 Subject: [PATCH] update --- src/diffusers/pipelines/chroma/pipeline_chroma.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index 02b7184114..a86138f569 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -549,7 +549,7 @@ class ChromaPipeline( return latents, latent_image_ids def _prepare_attention_mask( - self, batch_size, sequence_length, prompt_attention_mask, negative_prompt_attention_mask=None + self, batch_size, sequence_length, dtype, prompt_attention_mask=None, negative_prompt_attention_mask=None ): attention_mask = None if prompt_attention_mask is not None: @@ -558,6 +558,7 @@ class ChromaPipeline( [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=prompt_attention_mask.device)], dim=1, ) + attention_mask = attention_mask.to(dtype) negative_attention_mask = None if negative_prompt_attention_mask is not None: @@ -568,6 +569,7 @@ class ChromaPipeline( ], dim=1, ) + negative_attention_mask = negative_attention_mask.to(dtype) return attention_mask, negative_attention_mask @@ -788,6 +790,7 @@ class ChromaPipeline( attention_mask, negative_attention_mask = self._prepare_attention_mask( batch_size=latents.shape[0], sequence_length=image_seq_len, + dtype=latents.dtype, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, ) @@ -853,7 +856,7 @@ class ChromaPipeline( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, - attention_mask=attention_mask.to(latents.dtype) if attention_mask is not None else None, + attention_mask=attention_mask, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] @@ -867,9 +870,7 @@ class ChromaPipeline( encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids, img_ids=latent_image_ids, - attention_mask=negative_attention_mask.to(latents.dtype) - if negative_attention_mask is not None - else None, + attention_mask=negative_attention_mask, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0]