rewrite Kandinsky5T2VPipeline to diffusers style
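The rewrite keeps the two-encoder design (Qwen2.5-VL for sequence embeddings, CLIP for a pooled embedding) but moves prompt handling into a diffusers-style `encode_prompt` that returns embedding dictionaries plus cumulative sequence lengths. A minimal sketch of how a caller might use it, not part of the commit itself: the checkpoint id is reused from the example docstring in the diff below, and the variable names are illustrative only.

```python
from diffusers import Kandinsky5T2VPipeline

# Checkpoint id taken from the example docstring in this diff.
pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
pipe = pipe.to("cuda")

# encode_prompt now returns (prompt_embeds_dict, negative_prompt_embeds_dict,
# prompt_cu_seqlens, negative_cu_seqlens); each dict holds the Qwen sequence
# embeddings under "text_embeds" and the pooled CLIP embedding under "pooled_embed".
prompt_dict, negative_dict, cu_seqlens, negative_cu_seqlens = pipe.encode_prompt(
    prompt="A cat and a dog baking a cake together in a kitchen.",
    negative_prompt="Bright tones, overexposed, static, blurred details",
    do_classifier_free_guidance=True,
    num_videos_per_prompt=1,
    max_sequence_length=512,
)

print(prompt_dict["text_embeds"].shape)   # Qwen2.5-VL sequence embeddings
print(prompt_dict["pooled_embed"].shape)  # pooled CLIP embedding
print(cu_seqlens)                         # cumulative sequence lengths, used for text RoPE positions
```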
@@ -75,6 +75,101 @@ EXAMPLE_DOC_STRING = """
        ```
"""

# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
from typing import Any, Callable, Dict, List, Optional, Union

import regex as re
import torch
from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer
import torchvision
from torchvision.transforms import ToPILImage

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import KandinskyLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo
from ...models.transformers import Kandinsky5Transformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import KandinskyPipelineOutput

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_ftfy_available():
    import ftfy


logger = logging.get_logger(__name__)

EXAMPLE_DOC_STRING = """
    Examples:

        ```python
        >>> import torch
        >>> from diffusers import Kandinsky5T2VPipeline
        >>> from diffusers.utils import export_to_video

        >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
        >>> pipe = pipe.to("cuda")

        >>> prompt = "A cat and a dog baking a cake together in a kitchen."
        >>> negative_prompt = "Bright tones, overexposed, static, blurred details"

        >>> output = pipe(
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     height=512,
        ...     width=768,
        ...     num_frames=25,
        ...     num_inference_steps=50,
        ...     guidance_scale=5.0,
        ... ).frames[0]
        >>> export_to_video(output, "output.mp4", fps=6)
        ```
"""

def basic_clean(text):
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    text = whitespace_clean(basic_clean(text))
    return text

class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
    r"""
@@ -96,9 +191,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
            Frozen CLIP text encoder.
        tokenizer_2 ([`CLIPTokenizer`]):
            Tokenizer for CLIP.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
@@ -125,6 +222,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):

        self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
        self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _encode_prompt_qwen(
        self,
@@ -132,9 +230,12 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 256,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(p) for p in prompt]

        # Kandinsky specific prompt template
        prompt_template = "\n".join([
@@ -180,16 +281,19 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
        embeds = embeds.repeat(1, num_videos_per_prompt, 1)
        embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return embeds, cu_seqlens
        return embeds.to(dtype), cu_seqlens

    def _encode_prompt_clip(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder_2.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(p) for p in prompt]

        inputs = self.tokenizer_2(
            prompt,
@@ -208,7 +312,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
        pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1)
        pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1)

        return pooled_embed
        return pooled_embed.to(dtype)

    def encode_prompt(
        self,
@@ -216,34 +320,151 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            max_sequence_length (`int`, *optional*, defaults to 512):
                Maximum sequence length for text encoding.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(prompt, device, num_videos_per_prompt)
        pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt)
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0]

        if do_classifier_free_guidance:
        if prompt_embeds is None:
            prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
                prompt=prompt,
                device=device,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                dtype=dtype,
            )
            prompt_embeds_clip = self._encode_prompt_clip(
                prompt=prompt,
                device=device,
                num_videos_per_prompt=num_videos_per_prompt,
                dtype=dtype,
            )
        else:
            prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(negative_prompt, device, num_videos_per_prompt)
            negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt)

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen(
                prompt=negative_prompt,
                device=device,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                dtype=dtype,
            )
            negative_prompt_embeds_clip = self._encode_prompt_clip(
                prompt=negative_prompt,
                device=device,
                num_videos_per_prompt=num_videos_per_prompt,
                dtype=dtype,
            )
        else:
            negative_prompt_embeds = None
            negative_pooled_embed = None
            negative_prompt_embeds_qwen = None
            negative_prompt_embeds_clip = None
            negative_cu_seqlens = None

        text_embeds = {
            "text_embeds": prompt_embeds,
            "pooled_embed": pooled_embed,
        prompt_embeds_dict = {
            "text_embeds": prompt_embeds_qwen,
            "pooled_embed": prompt_embeds_clip,
        }
        negative_text_embeds = {
            "text_embeds": negative_prompt_embeds,
            "pooled_embed": negative_pooled_embed,
        negative_prompt_embeds_dict = {
            "text_embeds": negative_prompt_embeds_qwen,
            "pooled_embed": negative_prompt_embeds_clip,
        } if do_classifier_free_guidance else None

        return text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens
        return prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens

    def check_inputs(
        self,
        prompt,
        negative_prompt,
        height,
        width,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif negative_prompt is not None and (
            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
        ):
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    def prepare_latents(
        self,
@@ -252,34 +473,31 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        visual_cond: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            num_latent_frames = latents.shape[1]
            latents = latents.to(device=device, dtype=dtype)
            return latents.to(device=device, dtype=dtype)

        else:
            num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
            shape = (
                batch_size,
                num_latent_frames,
                int(height) // self.vae_scale_factor_spatial,
                int(width) // self.vae_scale_factor_spatial,
                num_channels_latents,
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
            num_channels_latents,
        )
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        if visual_cond:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        if self.transformer.visual_cond:
            # For visual conditioning, concatenate with zeros and mask
            visual_cond = torch.zeros_like(latents)
            visual_cond_mask = torch.zeros(
@@ -291,26 +509,46 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):

        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        prompt: Union[str, List[str]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 768,
        num_frames: int = 121,
        num_frames: int = 25,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        scheduler_scale: float = 10.0,
        num_videos_per_prompt: int = 1,
        generator: Optional[torch.Generator] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        **kwargs,
    ):
        r"""
@@ -318,9 +556,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the video generation.
                The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to avoid during video generation.
                The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (`guidance_scale` < `1`).
            height (`int`, defaults to `512`):
                The height in pixels of the generated video.
            width (`int`, defaults to `768`):
@@ -335,82 +574,109 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                Scale factor for the custom flow matching scheduler.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator`, *optional*):
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A torch generator to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`KandinskyPipelineOutput`].
            callback_on_step_end (`Callable`, *optional*):
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function that is called at the end of each denoising step.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function.
            max_sequence_length (`int`, defaults to `512`):
                The maximum sequence length for text encoding.

        Examples:

        Returns:
            [`~KandinskyPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned where
                the first element is a list with the generated images and the second element is a list of `bool`s
                indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
                the first element is a list with the generated images.
        """
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 0. Reset embeddings dtype
        self.transformer.time_embeddings.reset_dtype()
        self.transformer.text_rope_embeddings.reset_dtype()
        self.transformer.visual_rope_embeddings.reset_dtype()

        dtype = self.transformer.dtype

        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            negative_prompt,
            height,
            width,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        if isinstance(prompt, str):
            batch_size = 1
        else:
            batch_size = len(prompt)

        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0

        if num_frames % self.vae_scale_factor_temporal != 1:
            logger.warning(
                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
            )
            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
        self._guidance_scale = guidance_scale
        self._interrupt = False

        device = self._execution_device
        dtype = self.transformer.dtype

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0]

        # 3. Encode input prompt
        prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = 16
        latents = self.prepare_latents(
            batch_size=batch_size * num_videos_per_prompt,
            num_channels_latents=16,
            height=height,
            width=width,
            num_frames=num_frames,
            visual_cond=self.transformer.visual_cond,
            dtype=dtype,
            device=device,
            generator=generator,
            latents=latents,
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            dtype,
            device,
            generator,
            latents,
        )

        visual_cond = latents[:, :, :, :, 16:]

        # 6. Prepare rope positions
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        visual_rope_pos = [
            torch.arange(num_frames // 4 + 1, device=device),
            torch.arange(height // 8 // 2, device=device),
            torch.arange(width // 8 // 2, device=device),
            torch.arange(num_latent_frames, device=device),
            torch.arange(height // self.vae_scale_factor_spatial // 2, device=device),
            torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
        ]

        text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device)
@@ -421,52 +687,72 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
            else None
        )

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                timestep = t.unsqueeze(0).flatten()

                with torch.autocast(device_type="cuda", dtype=dtype):
                    pred_velocity = self.transformer(
                        hidden_states=latents,
                        encoder_hidden_states=text_embeds["text_embeds"],
                        pooled_projections=text_embeds["pooled_embed"],
                        timestep=timestep,


                # Predict noise residual
                # with torch.autocast(device_type="cuda", dtype=dtype):
                pred_velocity = self.transformer(
                    hidden_states=latents.to(dtype),
                    encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype),
                    pooled_projections=prompt_embeds_dict["pooled_embed"].to(dtype),
                    timestep=timestep.to(dtype),
                    visual_rope_pos=visual_rope_pos,
                    text_rope_pos=text_rope_pos,
                    scale_factor=(1, 2, 2),
                    sparse_params=None,
                    return_dict=True
                ).sample

                if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None:
                    uncond_pred_velocity = self.transformer(
                        hidden_states=latents.to(dtype),
                        encoder_hidden_states=negative_prompt_embeds_dict["text_embeds"].to(dtype),
                        pooled_projections=negative_prompt_embeds_dict["pooled_embed"].to(dtype),
                        timestep=timestep.to(dtype),
                        visual_rope_pos=visual_rope_pos,
                        text_rope_pos=text_rope_pos,
                        scale_factor=(1, 2, 2),
                        text_rope_pos=negative_text_rope_pos,
                        scale_factor=(1, 2, 2),
                        sparse_params=None,
                        return_dict=True
                    ).sample

                    if guidance_scale > 1.0 and negative_text_embeds is not None:
                        uncond_pred_velocity = self.transformer(
                            hidden_states=latents,
                            encoder_hidden_states=negative_text_embeds["text_embeds"],
                            pooled_projections=negative_text_embeds["pooled_embed"],
                            timestep=timestep,
                            visual_rope_pos=visual_rope_pos,
                            text_rope_pos=negative_text_rope_pos,
                            scale_factor=(1, 2, 2),
                            sparse_params=None,
                            return_dict=True
                        ).sample

                        pred_velocity = uncond_pred_velocity + guidance_scale * (
                            pred_velocity - uncond_pred_velocity
                        )
                    pred_velocity = uncond_pred_velocity + guidance_scale * (
                        pred_velocity - uncond_pred_velocity
                    )

                latents[:, :, :, :, :16] = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0]
                # Compute previous sample
                latents[:, :, :, :, :16] = self.scheduler.step(
                    pred_velocity, t, latents[:, :, :, :, :16], return_dict=False
                )[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)

                    prompt_embeds_dict = callback_outputs.pop("prompt_embeds", prompt_embeds_dict)
                    negative_prompt_embeds_dict = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds_dict)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()


                if XLA_AVAILABLE:
                    xm.mark_step()

        # 8. Post-processing
        latents = latents[:, :, :, :, :16]

        # 9. Decode latents to video
@@ -477,26 +763,23 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
            batch_size,
            num_videos_per_prompt,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            height // 8,
            width // 8,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
            16,
        )
        video = video.permute(0, 1, 5, 2, 3, 4)  # [batch, num_videos, channels, frames, height, width]
        video = video.reshape(batch_size * num_videos_per_prompt, 16, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // 8, width // 8)
        video = video.reshape(
            batch_size * num_videos_per_prompt,
            16,
            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial
        )

        # Normalize and decode
        video = video / self.vae.config.scaling_factor
        video = self.vae.decode(video).sample
        video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
        # Convert to output format
        if output_type == "pil":
            if num_frames == 1:
                # Single image
                video = [ToPILImage()(frame.squeeze(1)) for frame in video]
            else:
                # Video frames
                video = [video[i] for i in range(video.shape[0])]

            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents
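Not part of the commit: a minimal sketch of the step-end callback path this rewrite wires up (`callback_on_step_end`, `callback_on_step_end_tensor_inputs`, and the registered `_callback_tensor_inputs`). The function name and the statistic printed are illustrative only; the checkpoint id is the one from the example docstring above.

```python
from diffusers import Kandinsky5T2VPipeline

pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
pipe = pipe.to("cuda")


def log_latent_stats(pipeline, step, timestep, callback_kwargs):
    # "latents" is one of the pipeline's registered _callback_tensor_inputs.
    latents = callback_kwargs["latents"]
    print(f"step {step}: latent mean {latents.float().mean().item():.4f}")
    # Whatever is returned is popped back into the denoising loop
    # (see callback_outputs.pop("latents", latents) in the diff above).
    return callback_kwargs


output = pipe(
    prompt="A cat and a dog baking a cake together in a kitchen.",
    num_frames=25,
    num_inference_steps=50,
    callback_on_step_end=log_latent_stats,
    callback_on_step_end_tensor_inputs=["latents"],
).frames[0]
```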