commit 188b0d2a2f (https://github.com/huggingface/diffusers)
parent 9019e92899
Author: Dhruv Nair
Date:   2025-06-16 19:32:19 +02:00


@@ -235,6 +235,7 @@ class ChromaPipeline(
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        attention_mask = attention_mask.to(dtype=dtype, device=device)

         _, seq_len, _ = prompt_embeds.shape
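A minimal sketch of the line added above, with placeholder tensors standing in for the tokenizer output and encoder properties: tokenizer attention masks come back as int64, so the added line moves the mask to the text encoder's dtype and device alongside the prompt embeddings.

```python
import torch

# Placeholder stand-ins; the real values come from the tokenizer and T5 encoder.
attention_mask = torch.ones(1, 512, dtype=torch.long)  # tokenizer masks are int64
dtype, device = torch.float32, torch.device("cpu")

# The added line: align the mask's dtype/device with the prompt embeddings.
attention_mask = attention_mask.to(dtype=dtype, device=device)
assert attention_mask.dtype == torch.float32
```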
@@ -331,6 +332,7 @@ class ChromaPipeline(
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
+            negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

         if self.text_encoder is not None:
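A minimal sketch of the added negative_text_ids line, with assumed shapes: Chroma, like Flux, hands the transformer a (seq_len, 3) block of zeros as the rotary-position ids for text tokens, and the negative prompt needs its own block of the same form.

```python
import torch

# Assumed stand-in for the encoder output: (batch, seq_len, dim).
negative_prompt_embeds = torch.randn(1, 512, 4096)
device, dtype = torch.device("cpu"), torch.float32

# One (seq_len, 3) row of zero position ids per negative text token.
negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
assert negative_text_ids.shape == (512, 3)
```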
@@ -547,7 +549,9 @@ class ChromaPipeline(
         self, batch_size, sequence_length, prompt_attention_mask, negative_prompt_attention_mask=None
     ):
         device = prompt_attention_mask.device
-        attention_mask = torch.cat([prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)])
+        attention_mask = torch.cat(
+            [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)], dim=1
+        )

         negative_attention_mask = None
         if negative_prompt_attention_mask is not None:
@@ -555,7 +559,8 @@ class ChromaPipeline(
                 [
                     negative_prompt_attention_mask,
                     torch.ones(batch_size, sequence_length, device=device),
-                ]
+                ],
+                dim=1,
             )

         return attention_mask, negative_attention_mask
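A minimal sketch, with illustrative sizes, of why the dim=1 fix in _prepare_attention_mask matters: the all-ones mask covering the image tokens must be appended along the sequence axis, so each sample's text mask is extended rather than stacked.

```python
import torch

batch_size, text_len, image_len = 2, 5, 7
prompt_attention_mask = torch.ones(batch_size, text_len)

# Without dim=1, torch.cat concatenates along dim 0, which raises a shape
# mismatch here (or silently doubles the batch when text_len == image_len).
joint_mask = torch.cat(
    [prompt_attention_mask, torch.ones(batch_size, image_len)], dim=1
)
assert joint_mask.shape == (2, 12)  # text tokens followed by image tokens
```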
@@ -759,7 +764,7 @@ class ChromaPipeline(
             latents,
         )
-        prompt_attention_mask, negative_prompt_attention_mask = self._prepare_attention_mask(
+        attention_mask, negative_attention_mask = self._prepare_attention_mask(
             latents.shape[0],
             latents.shape[1],
             prompt_attention_mask,
@@ -837,7 +842,7 @@ class ChromaPipeline(
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
                     img_ids=latent_image_ids,
-                    attention_mask=prompt_attention_mask,
+                    attention_mask=attention_mask.to(latents.dtype),
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -851,7 +856,7 @@ class ChromaPipeline(
                     encoder_hidden_states=negative_prompt_embeds,
                     txt_ids=negative_text_ids,
                     img_ids=latent_image_ids,
-                    attention_mask=negative_prompt_attention_mask,
+                    attention_mask=negative_attention_mask.to(latents.dtype),
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
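A minimal sketch, with hypothetical tensors, of the .to(latents.dtype) casts in the two transformer calls above: under fp16/bf16 inference the latents run in reduced precision, while the joint masks built by _prepare_attention_mask default to float32, so each mask is cast to match before being passed in.

```python
import torch

latents = torch.randn(1, 4096, 64, dtype=torch.bfloat16)  # half-precision latents
attention_mask = torch.ones(1, 4608)  # joint mask defaults to float32

# The cast applied at both the positive and negative transformer calls.
attention_mask = attention_mask.to(latents.dtype)
assert attention_mask.dtype == torch.bfloat16
```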