mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
remove mask function
This commit is contained in:
@@ -198,13 +198,6 @@ class ChromaPipeline(
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.default_sample_size = 128
|
||||
|
||||
def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
|
||||
attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
|
||||
for i, n_tokens in enumerate(length):
|
||||
n_tokens = torch.max(n_tokens + 1, max_sequence_length)
|
||||
attention_mask[i, :n_tokens] = True
|
||||
return attention_mask
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -234,12 +227,12 @@ class ChromaPipeline(
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
text_inputs.attention_mask[:, : text_inputs.length + 1] = 1.0
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
output_hidden_states=False,
|
||||
attention_mask=(
|
||||
self._get_chroma_attn_mask(text_inputs.length, max_sequence_length).to(device)
|
||||
),
|
||||
attention_mask=text_inputs.attention_mask.to(device),
|
||||
)[0]
|
||||
|
||||
dtype = self.text_encoder.dtype
|
||||
|
||||
Reference in New Issue
Block a user