commit 1999bffda8
parent 544dad4c25
Author: Dhruv Nair
Date:   2025-06-17 09:56:53 +02:00

@@ -549,29 +549,23 @@ class ChromaPipeline(
         return latents, latent_image_ids

     def _prepare_attention_mask(
-        self, batch_size, sequence_length, dtype, prompt_attention_mask=None, negative_prompt_attention_mask=None
+        self,
+        batch_size,
+        sequence_length,
+        dtype,
+        attention_mask=None,
     ):
-        attention_mask = None
-        if prompt_attention_mask is not None:
-            # Extend the prompt attention mask to account for image tokens in the final sequence
-            attention_mask = torch.cat(
-                [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=prompt_attention_mask.device)],
-                dim=1,
-            )
-            attention_mask = attention_mask.to(dtype)
+        if attention_mask is None:
+            return attention_mask

-        negative_attention_mask = None
-        if negative_prompt_attention_mask is not None:
-            negative_attention_mask = torch.cat(
-                [
-                    negative_prompt_attention_mask,
-                    torch.ones(batch_size, sequence_length, device=negative_prompt_attention_mask.device),
-                ],
-                dim=1,
-            )
-            negative_attention_mask = negative_attention_mask.to(dtype)
+        # Extend the prompt attention mask to account for image tokens in the final sequence
+        attention_mask = torch.cat(
+            [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+            dim=1,
+        )
+        attention_mask = attention_mask.to(dtype)

-        return attention_mask, negative_attention_mask
+        return attention_mask

     @property
     def guidance_scale(self):
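
The hunk above replaces two near-duplicate branches with a single-mask helper that the caller invokes once per mask. A minimal standalone sketch of the resulting behavior, with toy shapes that are illustrative only (the real pipeline derives sequence_length from the packed image latents):

import torch

def prepare_attention_mask(batch_size, sequence_length, dtype, attention_mask=None):
    # Mirror of the refactored helper: a no-op when no mask is supplied.
    if attention_mask is None:
        return attention_mask
    # Extend the text-token mask with all-ones entries so it also covers the
    # image tokens that are concatenated onto the sequence in the transformer.
    attention_mask = torch.cat(
        [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
        dim=1,
    )
    return attention_mask.to(dtype)

# Illustrative shapes only: 2 prompts padded to 8 text tokens, 4 image tokens.
text_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                          [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]])
full_mask = prepare_attention_mask(2, 4, torch.float32, attention_mask=text_mask)
print(full_mask.shape)  # torch.Size([2, 12])
assert prepare_attention_mask(2, 4, torch.float32, attention_mask=None) is None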
@@ -787,12 +781,17 @@ class ChromaPipeline(
             self.scheduler.config.get("max_shift", 1.15),
         )

-        attention_mask, negative_attention_mask = self._prepare_attention_mask(
+        attention_mask = self._prepare_attention_mask(
             batch_size=latents.shape[0],
             sequence_length=image_seq_len,
             dtype=latents.dtype,
-            prompt_attention_mask=prompt_attention_mask,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            attention_mask=prompt_attention_mask,
+        )
+        negative_attention_mask = self._prepare_attention_mask(
+            batch_size=latents.shape[0],
+            sequence_length=image_seq_len,
+            dtype=latents.dtype,
+            attention_mask=negative_prompt_attention_mask,
         )

         timesteps, num_inference_steps = retrieve_timesteps(
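
The call site now goes through the helper twice, once per mask, instead of receiving a tuple. A quick sanity check that the refactor is behavior-preserving, on toy tensors; old_prepare below re-implements the pre-refactor logic for comparison and is not part of the pipeline:

import torch

def old_prepare(batch_size, seq_len, dtype, prompt_mask=None, negative_mask=None):
    # Pre-refactor shape of the helper: builds both masks in one call.
    def extend(mask):
        if mask is None:
            return None
        ones = torch.ones(batch_size, seq_len, device=mask.device)
        return torch.cat([mask, ones], dim=1).to(dtype)
    return extend(prompt_mask), extend(negative_mask)

def new_prepare(batch_size, seq_len, dtype, attention_mask=None):
    # Post-refactor shape: one mask per call.
    if attention_mask is None:
        return None
    ones = torch.ones(batch_size, seq_len, device=attention_mask.device)
    return torch.cat([attention_mask, ones], dim=1).to(dtype)

prompt_mask, negative_mask = torch.ones(2, 8), torch.zeros(2, 8)
old_pos, old_neg = old_prepare(2, 4, torch.float32, prompt_mask, negative_mask)
new_pos = new_prepare(2, 4, torch.float32, attention_mask=prompt_mask)
new_neg = new_prepare(2, 4, torch.float32, attention_mask=negative_mask)
assert torch.equal(old_pos, new_pos) and torch.equal(old_neg, new_neg)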