mirror of https://github.com/huggingface/diffusers.git

Commit: add multiprompt support
@@ -269,18 +269,21 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             output_hidden_states=True,
         )["hidden_states"][-1][:, crop_start:]
 
+        batch_size = len(prompt)
+
         attention_mask = inputs["attention_mask"][:, crop_start:]
         embeds = embeds[attention_mask.bool()]
         cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
         cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)
 
-        # duplicate for each generation per prompt
-        batch_size = len(prompt)
-        seq_len = embeds.shape[0] // batch_size
-        embeds = embeds.reshape(batch_size, seq_len, -1)
-        embeds = embeds.repeat(1, num_videos_per_prompt, 1)
-        embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+        # # duplicate for each generation per prompt
+        # seq_len = embeds.shape[0] // batch_size
+        # embeds = embeds.reshape(batch_size, seq_len, -1)
+        # embeds = embeds.repeat(1, num_videos_per_prompt, 1)
+        # embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+        # print(embeds.shape, cu_seqlens, "ENCODE PROMPT")
+        embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0)
 
         return embeds.to(dtype), cu_seqlens
 
     def _encode_prompt_clip(
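Note on the hunk above: after masking, the prompt embeddings are packed into one flat (total_tokens, dim) tensor, and cu_seqlens records where each prompt's tokens start and end, which is what lets one forward pass serve prompts of different lengths. The fixed-shape reshape/view duplication only worked when every prompt had the same token count, so it is retired here. A minimal sketch of the cu_seqlens bookkeeping, using a made-up two-prompt attention mask (values are illustrative, not from the pipeline):

    import torch

    # Two prompts of real length 3 and 2, padded to seq_len 4.
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    embeds = torch.randn(2, 4, 8)  # (batch, seq, dim) hidden states

    # Pack: keep only real tokens, flattening all prompts together.
    packed = embeds[attention_mask.bool()]  # shape (5, 8)

    # Cumulative sequence lengths delimit each prompt inside the packed tensor.
    cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
    cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)
    print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)

Tokens for prompt i then live at packed[cu_seqlens[i]:cu_seqlens[i + 1]].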
@@ -679,10 +682,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
         ]
 
-        text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device)
+        text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
 
         negative_text_rope_pos = (
-            torch.arange(negative_cu_seqlens[-1].item(), device=device)
+            torch.arange(negative_cu_seqlens.diff().max().item(), device=device)
             if negative_cu_seqlens is not None
             else None
         )
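Note on the hunk above: prompt_cu_seqlens[-1] is the combined length of all packed prompts, so with more than one prompt the old code allocated RoPE positions for the sum of the lengths. Positions are per prompt, so only the longest segment matters, and cu_seqlens.diff() recovers the individual lengths. A small sketch with assumed values:

    import torch

    # cu_seqlens = [0, len_0, len_0 + len_1] for two packed prompts.
    cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)

    total = cu_seqlens[-1].item()             # 5: whole packed buffer
    longest = cu_seqlens.diff().max().item()  # 3: longest single prompt

    # One position id per token of the longest prompt; shorter prompts
    # simply use a prefix of this range.
    text_rope_pos = torch.arange(longest)     # tensor([0, 1, 2])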
@@ -696,12 +699,19 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                 if self.interrupt:
                     continue
 
-                timestep = t.unsqueeze(0).flatten()
+                timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)
 
+                # Predict noise residual
                 # with torch.autocast(device_type="cuda", dtype=dtype):
-                # Predict noise residual
+                # print(
+                #     latents.shape,
+                #     prompt_embeds_dict["text_embeds"].shape,
+                #     prompt_embeds_dict["pooled_embed"].shape,
+                #     timestep.shape,
+                #     [el.shape for el in visual_rope_pos],
+                #     text_rope_pos.shape,
+                #     prompt_cu_seqlens,
+                # )
 
                 pred_velocity = self.transformer(
                     hidden_states=latents.to(dtype),
                     encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype),
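Note on the hunk above: t.unsqueeze(0).flatten() always yields a single-element timestep, which the transformer would have to broadcast; repeating it batch_size * num_videos_per_prompt times hands the model one timestep per sample explicitly. A tiny sketch with assumed shapes:

    import torch

    t = torch.tensor(999.0)  # a 0-d scheduler timestep
    batch_size, num_videos_per_prompt = 2, 2

    old = t.unsqueeze(0).flatten()  # shape (1,), regardless of batch
    new = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)  # shape (4,)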