Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
remove no_grad and simplify prompt padding
@@ -17,6 +17,7 @@ from typing import Callable, Dict, List, Optional, Union
 
 import regex as re
 import torch
+from torch.nn import functional as F
 from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -303,17 +304,19 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             padding=True,
         ).to(device)
 
-        with torch.no_grad():
-            embeds = self.text_encoder(
-                input_ids=inputs["input_ids"],
-                return_dict=True,
-                output_hidden_states=True,
-            )["hidden_states"][-1][:, crop_start:]
+        embeds = self.text_encoder(
+            input_ids=inputs["input_ids"],
+            return_dict=True,
+            output_hidden_states=True,
+        )["hidden_states"][-1][:, crop_start:]
 
         batch_size = len(prompt)
 
         attention_mask = inputs["attention_mask"][:, crop_start:]
         cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
-        cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)
+        # cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
 
         embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0)
 
         return embeds.to(dtype), cu_seqlens
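
The cu_seqlens change swaps a torch.cat-based zero prepend for F.pad; both build the cumulative-lengths tensor (with a leading zero) of the kind typically consumed by variable-length attention kernels. A minimal sketch checking the equivalence, with a toy attention_mask as the only assumption:

import torch
from torch.nn import functional as F

# Toy mask for two prompts of lengths 3 and 5 (1 = real token, 0 = padding).
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1],
])

cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)  # tensor([3, 8])

# Old construction: concatenate a zeroed one-element slice in front.
old = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)

# New construction: left-pad the last dimension with a single zero.
new = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)

assert torch.equal(old, new)  # both are tensor([0, 3, 8], dtype=torch.int32)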
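
The unchanged embeds = torch.cat([...]) line above is the per-prompt duplication for num_videos_per_prompt; a standalone sketch of that pattern, with made-up shapes:

import torch

batch_size, seq_len, dim = 2, 4, 8  # toy shapes, not the pipeline's
num_videos_per_prompt = 3
embeds = torch.randn(batch_size, seq_len, dim)

# Same pattern as the pipeline: repeat each prompt's embeddings
# num_videos_per_prompt times along the batch dimension.
duplicated = torch.cat(
    [embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)],
    dim=0,
)
assert duplicated.shape == (batch_size * num_videos_per_prompt, seq_len, dim)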
@@ -354,8 +357,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             return_tensors="pt",
         ).to(device)
 
-        with torch.no_grad():
-            pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
+        pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
 
         # duplicate for each generation per prompt
         batch_size = len(prompt)
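
Dropping the inner with torch.no_grad(): is harmless when the surrounding entry point already disables autograd, which is the usual diffusers convention for pipeline __call__ methods. A sketch of that pattern; the class and method bodies here are illustrative, not the pipeline's actual code:

import torch

class ToyPipeline:
    # Typical diffusers-style pattern: a single no_grad at the public
    # entry point covers every helper it calls, making nested no_grad
    # blocks inside encode_prompt redundant.
    @torch.no_grad()
    def __call__(self, prompt):
        return self._encode_prompt(prompt)

    def _encode_prompt(self, prompt):
        # Already runs under the caller's no_grad context; no inner
        # guard is needed for inference.
        return torch.randn(1, 4)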