From 723d149dc1dad0db009abcb210e671a775b23db6 Mon Sep 17 00:00:00 2001
From: leffff
Date: Fri, 10 Oct 2025 17:00:23 +0000
Subject: [PATCH] add multi-prompt support

---
 .../kandinsky5/pipeline_kandinsky.py          | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
index cea079251b..a417d99675 100644
--- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
@@ -269,18 +269,15 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             output_hidden_states=True,
         )["hidden_states"][-1][:, crop_start:]
 
+        batch_size = len(prompt)
+
         attention_mask = inputs["attention_mask"][:, crop_start:]
-        embeds = embeds[attention_mask.bool()]
         cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
         cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)
 
-        # duplicate for each generation per prompt
-        batch_size = len(prompt)
-        seq_len = embeds.shape[0] // batch_size
-        embeds = embeds.reshape(batch_size, seq_len, -1)
-        embeds = embeds.repeat(1, num_videos_per_prompt, 1)
-        embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+        # duplicate each prompt's embeddings for every generation requested per prompt
+        embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0)
 
         return embeds.to(dtype), cu_seqlens
 
     def _encode_prompt_clip(
@@ -679,10 +676,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
         ]
 
-        text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device)
+        text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
 
         negative_text_rope_pos = (
-            torch.arange(negative_cu_seqlens[-1].item(), device=device)
+            torch.arange(negative_cu_seqlens.diff().max().item(), device=device)
             if negative_cu_seqlens is not None
             else None
         )
@@ -696,12 +693,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                 if self.interrupt:
                     continue
 
-                timestep = t.unsqueeze(0).flatten()
-
-
-                # Predict noise residual
-                # with torch.autocast(device_type="cuda", dtype=dtype):
+                timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)
+
+                # Predict noise residual
                 pred_velocity = self.transformer(
                     hidden_states=latents.to(dtype),
                     encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype),
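
Note on the first hunk: the old code flattened all prompts into one token stream with embeds[attention_mask.bool()] and then reshaped, which silently assumed every prompt had the same length. The patch keeps the padded (batch, seq_len, dim) layout and repeats each prompt's embeddings per requested generation, while cu_seqlens records each prompt's true token count. Below is a minimal runnable sketch of that logic; the function name duplicate_per_prompt and the shapes are illustrative only, not part of the pipeline API:

import torch

def duplicate_per_prompt(embeds, attention_mask, num_videos_per_prompt):
    # embeds: (batch, seq_len, dim) padded text embeddings
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    batch_size = embeds.shape[0]

    # cu_seqlens[i] is the cumulative token count of the first i prompts,
    # so cu_seqlens.diff() recovers each prompt's true (unpadded) length.
    cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
    cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)

    # Repeat each prompt's embeddings num_videos_per_prompt times while
    # keeping prompts grouped: [p0, p0, ..., p1, p1, ...].
    embeds = torch.cat(
        [embeds[i].unsqueeze(0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)],
        dim=0,
    )
    return embeds, cu_seqlens

embeds = torch.randn(2, 5, 8)          # two prompts, padded to 5 tokens
mask = torch.tensor([[1, 1, 1, 0, 0],  # prompt 0 has 3 real tokens
                     [1, 1, 1, 1, 1]]) # prompt 1 has 5
out, cu = duplicate_per_prompt(embeds, mask, num_videos_per_prompt=2)
print(out.shape)  # torch.Size([4, 5, 8])
print(cu)         # tensor([0, 3, 8], dtype=torch.int32)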
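
Note on the rope-position hunks: these follow from the same layout change. In the old flattened stream, prompt_cu_seqlens[-1] (total tokens across the whole batch) was the right number of RoPE positions; with padded per-prompt sequences, positions only need to cover the longest single prompt, which cu_seqlens.diff().max() yields. A short sketch with illustrative values:

import torch

cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)  # two prompts: 3 and 5 real tokens

old_pos = torch.arange(cu_seqlens[-1].item())            # 8 positions: total token count,
                                                         # matched the old flattened layout
new_pos = torch.arange(cu_seqlens.diff().max().item())   # 5 positions: longest prompt,
                                                         # matches the padded per-prompt layout
print(old_pos.shape, new_pos.shape)                      # torch.Size([8]) torch.Size([5])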