mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00

update
@@ -549,6 +549,8 @@ class ChromaPipeline(
         self, batch_size, sequence_length, prompt_attention_mask, negative_prompt_attention_mask=None
     ):
         device = prompt_attention_mask.device
+
+        # Extend the prompt attention mask to account for image tokens in the final sequence
         attention_mask = torch.cat(
             [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)], dim=1
         )
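For context, a minimal standalone sketch in Python of the mask extension performed above; the batch size, token counts, and mask values are made up for illustration:

import torch

# Hypothetical shapes: 2 prompts of up to 4 text tokens, 6 image tokens.
batch_size, image_seq_len = 2, 6
prompt_attention_mask = torch.tensor(
    [[1, 1, 1, 0],   # padded prompt: last text position is masked out
     [1, 1, 1, 1]]
)

# Image tokens are always attended to, so a block of ones is appended,
# mirroring the torch.cat in the hunk above.
attention_mask = torch.cat(
    [prompt_attention_mask, torch.ones(batch_size, image_seq_len, dtype=prompt_attention_mask.dtype)],
    dim=1,
)
print(attention_mask.shape)  # torch.Size([2, 10])
print(attention_mask[0])     # tensor([1, 1, 1, 0, 1, 1, 1, 1, 1, 1])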
@@ -764,13 +766,6 @@ class ChromaPipeline(
             latents,
         )
-
-        attention_mask, negative_attention_mask = self._prepare_attention_mask(
-            latents.shape[0],
-            latents.shape[1],
-            prompt_attention_mask,
-            negative_prompt_attention_mask,
-        )
 
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = latents.shape[1]
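As a quick sanity check on the sigma schedule kept as context above (this just evaluates the linspace expression at its endpoints, nothing Chroma-specific):

import numpy as np

num_inference_steps = 4
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
print(sigmas)  # [1.   0.75 0.5  0.25], evenly spaced from 1.0 down to 1/steps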
@@ -781,6 +776,14 @@ class ChromaPipeline(
             self.scheduler.config.get("base_shift", 0.5),
             self.scheduler.config.get("max_shift", 1.15),
         )
+
+        attention_mask, negative_attention_mask = self._prepare_attention_mask(
+            batch_size=latents.shape[0],
+            sequence_length=image_seq_len,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+        )
+
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
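Taken together, a hypothetical self-contained sketch of the relocated call; the helper is stubbed from the first hunk, and extending the negative mask the same way is an assumption, since the diff only shows the positive branch:

import torch

def _prepare_attention_mask(batch_size, sequence_length, prompt_attention_mask,
                            negative_prompt_attention_mask=None):
    # Stub mirroring the first hunk; the negative branch is an assumption.
    device = prompt_attention_mask.device
    attention_mask = torch.cat(
        [prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)], dim=1
    )
    negative_attention_mask = None
    if negative_prompt_attention_mask is not None:
        negative_attention_mask = torch.cat(
            [negative_prompt_attention_mask, torch.ones(batch_size, sequence_length, device=device)],
            dim=1,
        )
    return attention_mask, negative_attention_mask

# Made-up shapes: one sample, 16 packed image tokens, 8 text tokens.
latents = torch.randn(1, 16, 64)
prompt_attention_mask = torch.ones(1, 8)
image_seq_len = latents.shape[1]

attention_mask, _ = _prepare_attention_mask(
    batch_size=latents.shape[0],
    sequence_length=image_seq_len,
    prompt_attention_mask=prompt_attention_mask,
)
print(attention_mask.shape)  # torch.Size([1, 24]), text tokens plus image tokens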