diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 88b435fb29..09883f54c7 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -198,13 +198,6 @@ class ChromaPipeline(
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.default_sample_size = 128
 
-    def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
-        attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
-        for i, n_tokens in enumerate(length):
-            n_tokens = torch.max(n_tokens + 1, max_sequence_length)
-            attention_mask[i, :n_tokens] = True
-        return attention_mask
-
     def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
@@ -234,12 +227,12 @@ class ChromaPipeline(
         text_input_ids = text_inputs.input_ids
         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
+        text_inputs.attention_mask[:, : text_inputs.length + 1] = 1.0
+
         prompt_embeds = self.text_encoder(
             text_input_ids.to(device),
            output_hidden_states=False,
-            attention_mask=(
-                self._get_chroma_attn_mask(text_inputs.length, max_sequence_length).to(device)
-            ),
+            attention_mask=text_inputs.attention_mask.to(device),
        )[0]
 
        dtype = self.text_encoder.dtype
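
For context, a minimal sketch of what the new mask update does, assuming a single prompt (batch size 1, so the tensor-valued `text_inputs.length` works as a slice bound) and a hypothetical padded length of 8: the tokenizer's attention mask covers only the real tokens, and the diff additionally unmasks one position past the last real token, i.e. the first padding token after EOS. The removed `_get_chroma_attn_mask` helper appears to have attempted the same per-sample unmasking, but its `torch.max(n_tokens + 1, max_sequence_length)` clamps upward rather than downward (presumably `torch.min` was intended), so it could never produce a partial mask.

import torch

# Stand-in values (hypothetical): 5 real tokens in an 8-token padded sequence.
max_sequence_length = 8
length = torch.tensor([5])  # per-sample token count, as returned with return_length=True

# Mask as the tokenizer would produce it: 1 for real tokens, 0 for padding.
attention_mask = torch.zeros((1, max_sequence_length))
attention_mask[:, :length] = 1.0

# The diff's update: also unmask the first padding position after the prompt.
attention_mask[:, : length + 1] = 1.0
print(attention_mask)  # tensor([[1., 1., 1., 1., 1., 1., 0., 0.]])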