
rewrite Kandinsky5T2VPipeline to diffusers style

leffff
2025-10-10 14:39:59 +00:00
parent 0bd738f52b
commit c8f3a36fba


@@ -75,6 +75,101 @@ EXAMPLE_DOC_STRING = """
```
"""
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
from typing import Any, Callable, Dict, List, Optional, Union
import regex as re
import torch
from transformers import (
    AutoProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    Qwen2TokenizerFast,
    Qwen2VLProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
import torchvision
from torchvision.transforms import ToPILImage
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import KandinskyLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo
from ...models.transformers import Kandinsky5Transformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import KandinskyPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import Kandinsky5T2VPipeline
>>> from diffusers.utils import export_to_video
>>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
>>> pipe = pipe.to("cuda")
>>> prompt = "A cat and a dog baking a cake together in a kitchen."
>>> negative_prompt = "Bright tones, overexposed, static, blurred details"
>>> output = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=512,
... width=768,
... num_frames=25,
... num_inference_steps=50,
... guidance_scale=5.0,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=6)
```
"""
def basic_clean(text):
if is_ftfy_available():
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
text = whitespace_clean(basic_clean(text))
return text
class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
r"""
@@ -96,9 +191,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
Frozen CLIP text encoder.
tokenizer_2 ([`CLIPTokenizer`]):
Tokenizer for CLIP.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
@@ -125,6 +222,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _encode_prompt_qwen(
self,
@@ -132,9 +230,12 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
device: Optional[torch.device] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 256,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(p) for p in prompt]
# Kandinsky specific prompt template
prompt_template = "\n".join([
@@ -180,16 +281,19 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
embeds = embeds.repeat(1, num_videos_per_prompt, 1)
embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return embeds, cu_seqlens
return embeds.to(dtype), cu_seqlens
def _encode_prompt_clip(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_videos_per_prompt: int = 1,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder_2.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(p) for p in prompt]
inputs = self.tokenizer_2(
prompt,
@@ -208,7 +312,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1)
pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1)
return pooled_embed
return pooled_embed.to(dtype)
def encode_prompt(
self,
@@ -216,34 +320,151 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier-free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
max_sequence_length (`int`, *optional*, defaults to 512):
Maximum sequence length for text encoding.
device (`torch.device`, *optional*):
torch device
dtype (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(prompt, device, num_videos_per_prompt)
pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0]
if do_classifier_free_guidance:
if prompt_embeds is None:
prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
prompt=prompt,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
dtype=dtype,
)
prompt_embeds_clip = self._encode_prompt_clip(
prompt=prompt,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
dtype=dtype,
)
else:
prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(negative_prompt, device, num_videos_per_prompt)
negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen(
prompt=negative_prompt,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
dtype=dtype,
)
negative_prompt_embeds_clip = self._encode_prompt_clip(
prompt=negative_prompt,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
dtype=dtype,
)
else:
negative_prompt_embeds = None
negative_pooled_embed = None
negative_prompt_embeds_qwen = None
negative_prompt_embeds_clip = None
negative_cu_seqlens = None
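# Bundle the Qwen2.5-VL sequence embeddings with the CLIP pooled embedding; the
# cumulative sequence lengths (cu_seqlens) are returned alongside so the caller
# can build the text RoPE positions.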
text_embeds = {
"text_embeds": prompt_embeds,
"pooled_embed": pooled_embed,
prompt_embeds_dict = {
"text_embeds": prompt_embeds_qwen,
"pooled_embed": prompt_embeds_clip,
}
negative_text_embeds = {
"text_embeds": negative_prompt_embeds,
"pooled_embed": negative_pooled_embed,
negative_prompt_embeds_dict = {
"text_embeds": negative_prompt_embeds_qwen,
"pooled_embed": negative_prompt_embeds_clip,
} if do_classifier_free_guidance else None
return text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens
return prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens
def check_inputs(
self,
prompt,
negative_prompt,
height,
width,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
def prepare_latents(
self,
@@ -252,34 +473,31 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
height: int = 480,
width: int = 832,
num_frames: int = 81,
visual_cond: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
num_latent_frames = latents.shape[1]
latents = latents.to(device=device, dtype=dtype)
return latents.to(device=device, dtype=dtype)
else:
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_latent_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
num_channels_latents,
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_latent_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
num_channels_latents,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if visual_cond:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if self.transformer.visual_cond:
# For visual conditioning, concatenate with zeros and mask
visual_cond = torch.zeros_like(latents)
visual_cond_mask = torch.zeros(
@@ -291,26 +509,46 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 768,
num_frames: int = 121,
num_frames: int = 25,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
scheduler_scale: float = 10.0,
num_videos_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
**kwargs,
):
r"""
@@ -318,9 +556,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation.
The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to avoid during video generation.
The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
instead. Ignored when not using guidance (`guidance_scale` < `1`).
height (`int`, defaults to `512`):
The height in pixels of the generated video.
width (`int`, defaults to `768`):
@@ -335,82 +574,109 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
Scale factor for the custom flow matching scheduler.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator`, *optional*):
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A torch generator to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`KandinskyPipelineOutput`].
callback_on_step_end (`Callable`, *optional*):
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function that is called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function.
max_sequence_length (`int`, defaults to `512`):
The maximum sequence length for text encoding.
Examples:
Returns:
[`~KandinskyPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Reset embeddings dtype
self.transformer.time_embeddings.reset_dtype()
self.transformer.text_rope_embeddings.reset_dtype()
self.transformer.visual_rope_embeddings.reset_dtype()
dtype = self.transformer.dtype
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
negative_prompt,
height,
width,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
if isinstance(prompt, str):
batch_size = 1
else:
batch_size = len(prompt)
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
if num_frames % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
)
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
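# For example, with a temporal compression ratio of 4, num_frames=24 is rounded to
# 24 // 4 * 4 + 1 = 25, making num_frames - 1 divisible by the compression ratio.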
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
self._guidance_scale = guidance_scale
self._interrupt = False
device = self._execution_device
dtype = self.transformer.dtype
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = 16
latents = self.prepare_latents(
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=16,
height=height,
width=width,
num_frames=num_frames,
visual_cond=self.transformer.visual_cond,
dtype=dtype,
device=device,
generator=generator,
latents=latents,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
dtype,
device,
generator,
latents,
)
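# Split off the visual-conditioning channels (empty when the transformer has no
# visual conditioning); only the first 16 channels are denoised in the loop below.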
visual_cond = latents[:, :, :, :, 16:]
# 6. Prepare rope positions
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
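# RoPE position ids: one index per latent frame, spatial indices halved to match the
# scale_factor=(1, 2, 2) passed to the transformer, and one index per text token
# (prompt_cu_seqlens[-1] tokens in total).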
visual_rope_pos = [
torch.arange(num_frames // 4 + 1, device=device),
torch.arange(height // 8 // 2, device=device),
torch.arange(width // 8 // 2, device=device),
torch.arange(num_latent_frames, device=device),
torch.arange(height // self.vae_scale_factor_spatial // 2, device=device),
torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
]
text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device)
@@ -421,52 +687,72 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
else None
)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
timestep = t.unsqueeze(0).flatten()
with torch.autocast(device_type="cuda", dtype=dtype):
pred_velocity = self.transformer(
hidden_states=latents,
encoder_hidden_states=text_embeds["text_embeds"],
pooled_projections=text_embeds["pooled_embed"],
timestep=timestep,
# Predict noise residual
# with torch.autocast(device_type="cuda", dtype=dtype):
pred_velocity = self.transformer(
hidden_states=latents.to(dtype),
encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype),
pooled_projections=prompt_embeds_dict["pooled_embed"].to(dtype),
timestep=timestep.to(dtype),
visual_rope_pos=visual_rope_pos,
text_rope_pos=text_rope_pos,
scale_factor=(1, 2, 2),
sparse_params=None,
return_dict=True
).sample
if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None:
uncond_pred_velocity = self.transformer(
hidden_states=latents.to(dtype),
encoder_hidden_states=negative_prompt_embeds_dict["text_embeds"].to(dtype),
pooled_projections=negative_prompt_embeds_dict["pooled_embed"].to(dtype),
timestep=timestep.to(dtype),
visual_rope_pos=visual_rope_pos,
text_rope_pos=text_rope_pos,
scale_factor=(1, 2, 2),
text_rope_pos=negative_text_rope_pos,
scale_factor=(1, 2, 2),
sparse_params=None,
return_dict=True
).sample
if guidance_scale > 1.0 and negative_text_embeds is not None:
uncond_pred_velocity = self.transformer(
hidden_states=latents,
encoder_hidden_states=negative_text_embeds["text_embeds"],
pooled_projections=negative_text_embeds["pooled_embed"],
timestep=timestep,
visual_rope_pos=visual_rope_pos,
text_rope_pos=negative_text_rope_pos,
scale_factor=(1, 2, 2),
sparse_params=None,
return_dict=True
).sample
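# Classifier-free guidance: start from the unconditional velocity and move
# guidance_scale times along the direction toward the text-conditioned velocity.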
pred_velocity = uncond_pred_velocity + guidance_scale * (
pred_velocity - uncond_pred_velocity
)
pred_velocity = uncond_pred_velocity + guidance_scale * (
pred_velocity - uncond_pred_velocity
)
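# Flow-matching Euler step: integrate the predicted velocity to obtain the previous
# latents; only the first 16 (denoised) channels are updated.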
latents[:, :, :, :, :16] = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0]
# Compute previous sample
latents[:, :, :, :, :16] = self.scheduler.step(
pred_velocity, t, latents[:, :, :, :, :16], return_dict=False
)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds_dict = callback_outputs.pop("prompt_embeds", prompt_embeds_dict)
negative_prompt_embeds_dict = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds_dict)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 8. Post-processing
latents = latents[:, :, :, :, :16]
# 9. Decode latents to video
@@ -477,26 +763,23 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
batch_size,
num_videos_per_prompt,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
height // 8,
width // 8,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
16,
)
video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width]
video = video.reshape(batch_size * num_videos_per_prompt, 16, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // 8, width // 8)
video = video.reshape(
batch_size * num_videos_per_prompt,
16,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial
)
# Normalize and decode
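# Latents are rescaled by 1 / vae.config.scaling_factor before decoding; decoded
# frames in [-1, 1] are then mapped to uint8 in [0, 255].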
video = video / self.vae.config.scaling_factor
video = self.vae.decode(video).sample
video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
# Convert to output format
if output_type == "pil":
if num_frames == 1:
# Single image
video = [ToPILImage()(frame.squeeze(1)) for frame in video]
else:
# Video frames
video = [video[i] for i in range(video.shape[0])]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents