From 327ab84d1923518ecc5314831254cfd70faf99e1 Mon Sep 17 00:00:00 2001
From: leffff
Date: Thu, 16 Oct 2025 06:50:57 +0000
Subject: [PATCH] remove no_grad and simplify prompt padding

---
 .../kandinsky5/pipeline_kandinsky.py          | 20 ++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
index 3d0d68cbe9..d4470a43d5 100644
--- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
@@ -17,6 +17,7 @@ from typing import Callable, Dict, List, Optional, Union
 
 import regex as re
 import torch
+from torch.nn import functional as F
 from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -303,17 +304,19 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             padding=True,
         ).to(device)
 
-        embeds = self.text_encoder(
-            input_ids=inputs["input_ids"],
-            return_dict=True,
-            output_hidden_states=True,
-        )["hidden_states"][-1][:, crop_start:]
-
+        embeds = self.text_encoder(
+            input_ids=inputs["input_ids"],
+            return_dict=True,
+            output_hidden_states=True,
+        )["hidden_states"][-1][:, crop_start:]
+
         batch_size = len(prompt)
 
         attention_mask = inputs["attention_mask"][:, crop_start:]
         cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
-        cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)
+        # cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
+
         embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0)
 
         return embeds.to(dtype), cu_seqlens
@@ -354,8 +357,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             return_tensors="pt",
         ).to(device)
 
-        with torch.no_grad():
-            pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
+        pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
 
         # duplicate for each generation per prompt
         batch_size = len(prompt)
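
Note on the cu_seqlens change in the second hunk: F.pad(cu_seqlens, (1, 0), value=0) and the replaced torch.cat construction both prepend a single zero to the cumulative sequence lengths (the zero-based offsets layout that varlen attention kernels expect), so the swap is behavior-preserving. A minimal self-contained check; the mask values are illustrative, not taken from the pipeline:

import torch
from torch.nn import functional as F

# Two prompts with valid-token lengths 3 and 2 (illustrative values only).
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)  # tensor([3, 5])

# Old construction: concatenate a single zero slice in front.
old = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)
# New construction: left-pad the 1-D tensor with one zero element.
new = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)

assert torch.equal(old, new)  # both: tensor([0, 3, 5], dtype=torch.int32)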
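
Note on dropping the inner torch.no_grad() in the third hunk: diffusers pipelines conventionally decorate __call__ with @torch.no_grad(), so a nested context around the CLIP pooled-embedding forward pass is redundant when prompt encoding runs inside the call; the assumption here is that Kandinsky5T2VPipeline follows that convention. A sketch of the pattern, with PipelineSketch as a hypothetical stand-in rather than the real class:

import torch

class PipelineSketch:
    # Hypothetical stand-in for the pipeline class; the decorator on
    # __call__ disables autograd for everything invoked inside it.
    @torch.no_grad()
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self._encode(x)

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        # Stands in for the text_encoder_2 forward pass; no inner
        # `with torch.no_grad():` is needed here.
        return x * 2.0

x = torch.ones(2, requires_grad=True)
out = PipelineSketch()(x)
assert not out.requires_grad  # autograd already disabled by the decorator

One trade-off: if the encoding helper is invoked directly, outside __call__, it will now track gradients unless the caller disables them.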