mirror of https://github.com/huggingface/diffusers.git

add multiprompt support

Author: leffff
Date:   2025-10-10 17:00:23 +00:00
parent 86b6c2b686
commit 723d149dc1


@@ -269,18 +269,21 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             output_hidden_states=True,
         )["hidden_states"][-1][:, crop_start:]

+        batch_size = len(prompt)
         attention_mask = inputs["attention_mask"][:, crop_start:]
         embeds = embeds[attention_mask.bool()]

         cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
         cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)

-        # duplicate for each generation per prompt
-        batch_size = len(prompt)
-        seq_len = embeds.shape[0] // batch_size
-        embeds = embeds.reshape(batch_size, seq_len, -1)
-        embeds = embeds.repeat(1, num_videos_per_prompt, 1)
-        embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+        # # duplicate for each generation per prompt
+        # seq_len = embeds.shape[0] // batch_size
+        # embeds = embeds.reshape(batch_size, seq_len, -1)
+        # embeds = embeds.repeat(1, num_videos_per_prompt, 1)
+        # embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+        # print(embeds.shape, cu_seqlens, "ENCODE PROMPT")
+        embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0)

         return embeds.to(dtype), cu_seqlens

    def _encode_prompt_clip(
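The commented-out path derived a single seq_len as embeds.shape[0] // batch_size, which only holds when every prompt packs to the same token length; the new list-comprehension instead duplicates each prompt's embedding block num_videos_per_prompt times. A minimal sketch of what that cat-and-repeat does, assuming embeds is a 3D (batch, seq, dim) tensor at this point; all shapes below are illustrative, not the pipeline's real tensors:

```python
import torch

# Illustrative shapes only.
batch_size, seq_len, dim, num_videos_per_prompt = 2, 5, 8, 3
embeds = torch.randn(batch_size, seq_len, dim)

# The loop-and-cat form from the commit...
a = torch.cat(
    [embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)],
    dim=0,
)
# ...repeats each prompt's block before moving on to the next prompt,
# i.e. repeat_interleave along the batch dimension.
b = embeds.repeat_interleave(num_videos_per_prompt, dim=0)
assert torch.equal(a, b)
assert a.shape == (batch_size * num_videos_per_prompt, seq_len, dim)
```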
@@ -679,10 +682,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
         ]

-        text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device)
+        text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)

         negative_text_rope_pos = (
-            torch.arange(negative_cu_seqlens[-1].item(), device=device)
+            torch.arange(negative_cu_seqlens.diff().max().item(), device=device)
             if negative_cu_seqlens is not None
             else None
         )
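cu_seqlens stores cumulative token counts with one boundary per prompt, so its last entry is the total across all prompts and matches a single prompt's length only when batch_size is 1. diff() recovers the per-prompt lengths and max() the longest one, so the RoPE positions cover every prompt in a multi-prompt batch. A small sketch with made-up lengths:

```python
import torch

# cu_seqlens holds cumulative token counts, one offset per prompt boundary.
# Made-up example: three prompts of 7, 12, and 9 tokens.
cu_seqlens = torch.tensor([0, 7, 19, 28], dtype=torch.int32)

lengths = cu_seqlens.diff()     # tensor([7, 12, 9]): per-prompt lengths
longest = lengths.max().item()  # 12: positions must cover the longest prompt
total = cu_seqlens[-1].item()   # 28: total tokens across all prompts; equals
                                # a prompt length only for a single prompt

text_rope_pos = torch.arange(longest)  # 0..11, enough for every prompt
```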
@@ -696,12 +699,19 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                 if self.interrupt:
                     continue

-                timestep = t.unsqueeze(0).flatten()
+                timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)

-                # Predict noise residual
-                # with torch.autocast(device_type="cuda", dtype=dtype):
+                # Predict noise residual
+                # print(
+                #     latents.shape,
+                #     prompt_embeds_dict["text_embeds"].shape,
+                #     prompt_embeds_dict["pooled_embed"].shape,
+                #     timestep.shape,
+                #     [el.shape for el in visual_rope_pos],
+                #     text_rope_pos.shape,
+                #     prompt_cu_seqlens,
+                # )
                 pred_velocity = self.transformer(
                     hidden_states=latents.to(dtype),
                     encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype),
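t.unsqueeze(0).flatten() always produces a single-element tensor, while the transformer now receives a combined batch of batch_size * num_videos_per_prompt latents; repeating the scheduler timestep gives one entry per sample. A minimal sketch with illustrative values:

```python
import torch

batch_size, num_videos_per_prompt = 2, 3
t = torch.tensor(999.0)  # one scheduler timestep; the value is illustrative

# Old form: a single-element tensor regardless of batch size.
single = t.unsqueeze(0).flatten()
assert single.shape == (1,)

# New form: one timestep entry per latent in the combined batch,
# matching latents of shape (batch_size * num_videos_per_prompt, ...).
per_sample = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)
assert per_sample.shape == (batch_size * num_videos_per_prompt,)
```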