From 188b0d2a2f3a84ab95df3d045d6e1bdb928efa0d Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Mon, 16 Jun 2025 19:32:19 +0200
Subject: [PATCH] update

---
 src/diffusers/pipelines/chroma/pipeline_chroma.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 5df1825628..9048efec8d 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -235,6 +235,7 @@ class ChromaPipeline(
 
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(dtype=dtype, device=device)
 
         _, seq_len, _ = prompt_embeds.shape
 
@@ -331,6 +332,7 @@ class ChromaPipeline(
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
+
             negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
         if self.text_encoder is not None:
@@ -547,7 +549,9 @@ class ChromaPipeline(
         self, batch_size, sequence_length, prompt_attention_mask, negative_prompt_attention_mask=None
     ):
         device = prompt_attention_mask.device
-        attention_mask = torch.cat([prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)])
+        attention_mask = torch.cat(
+            [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)], dim=1
+        )
 
         negative_attention_mask = None
         if negative_prompt_attention_mask is not None:
@@ -555,7 +559,8 @@ class ChromaPipeline(
                 [
                     negative_prompt_attention_mask,
                     torch.ones(batch_size, sequence_length, device=device),
-                ]
+                ],
+                dim=1,
             )
 
         return attention_mask, negative_attention_mask
@@ -759,7 +764,7 @@ class ChromaPipeline(
             latents,
         )
 
-        prompt_attention_mask, negative_prompt_attention_mask = self._prepare_attention_mask(
+        attention_mask, negative_attention_mask = self._prepare_attention_mask(
             latents.shape[0],
             latents.shape[1],
             prompt_attention_mask,
@@ -837,7 +842,7 @@ class ChromaPipeline(
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
-                    attention_mask=prompt_attention_mask,
+                    attention_mask=attention_mask.to(latents.dtype),
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -851,7 +856,7 @@ class ChromaPipeline(
                         encoder_hidden_states=negative_prompt_embeds,
                         txt_ids=negative_text_ids,
                         img_ids=latent_image_ids,
-                        attention_mask=negative_prompt_attention_mask,
+                        attention_mask=negative_attention_mask.to(latents.dtype),
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
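
Note (reviewer sketch, not part of the patch): one of the changes above adds dim=1 to the torch.cat calls in _prepare_attention_mask. torch.cat defaults to dim=0, so the text-token mask and the all-ones latent-token mask were previously stacked along the batch dimension rather than joined along the sequence dimension. A minimal, self-contained illustration with made-up sizes (2 prompts, 8 text tokens, 4 latent tokens; variable names here are illustrative, not the pipeline's):

import torch

# Illustrative shapes only: 2 prompts, 8 text tokens, 4 latent (image) tokens.
batch_size, text_len, image_len = 2, 8, 4
prompt_attention_mask = torch.ones(batch_size, text_len)
latent_token_mask = torch.ones(batch_size, image_len)

# Without dim=1, torch.cat defaults to dim=0 and tries to stack along the
# batch dimension, which fails with a size mismatch whenever
# text_len != image_len (and would yield a (2 * batch, text_len) mask otherwise).
attention_mask = torch.cat([prompt_attention_mask, latent_token_mask], dim=1)
assert attention_mask.shape == (batch_size, text_len + image_len)

The remaining hunks follow from this: the joint masks returned by _prepare_attention_mask (attention_mask / negative_attention_mask) are what get passed to the transformer, cast to latents.dtype, instead of the raw text-only prompt_attention_mask.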