1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Kandinsky 5 is finally in Diffusers! (#12478)

* add kandinsky5 transformer pipeline first version

---------

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Charles <charles@huggingface.co>
This commit is contained in:
Lev Novitskiy
2025-10-18 07:34:30 +03:00
committed by GitHub
parent 1b456bd5d5
commit 23ebbb4bc8
13 changed files with 1957 additions and 0 deletions

View File

@@ -107,6 +107,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin
## KandinskyLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin
## LoraBaseMixin
[[autodoc]] loaders.lora_base.LoraBaseMixin

View File

@@ -220,6 +220,7 @@ else:
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
@@ -474,6 +475,7 @@ else:
"ImageTextPipelineOutput",
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
"Kandinsky5T2VPipeline",
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
@@ -912,6 +914,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
Kandinsky3UNet,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
@@ -1136,6 +1139,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ImageTextPipelineOutput,
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
Kandinsky5T2VPipeline,
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,

View File

@@ -77,6 +77,7 @@ if is_torch_available():
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"KandinskyLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
@@ -115,6 +116,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,

View File

@@ -3639,6 +3639,291 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class KandinskyLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`Kandinsky5Transformer3DModel`],
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* of a pretrained model hosted on the Hub.
- A path to a *directory* containing the model weights.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
use_safetensors (`bool`, *optional*):
Whether to use safetensors for loading.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata.
"""
# Load the main state dict first which has the LoRA layers
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
hotswap (`bool`, *optional*):
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
# Load LoRA into transformer
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
Load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters.
transformer (`Kandinsky5Transformer3DModel`):
The transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights.
hotswap (`bool`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the transformer and text encoders.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process.
save_function (`Callable`):
The function to use to save the state dictionary.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers`")
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing.
Example:
```py
from diffusers import Kandinsky5T2VPipeline
pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
pipeline.load_lora_weights("path/to/lora.safetensors")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of [`pipe.fuse_lora()`].
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
super().unfuse_lora(components=components, **kwargs)
class WanLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].

View File

@@ -91,6 +91,7 @@ if is_torch_available():
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
@@ -182,6 +183,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,

View File

@@ -27,6 +27,7 @@ if is_torch_available():
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel

View File

@@ -0,0 +1,667 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import (
logging,
)
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import _CAN_USE_FLEX_ATTN, dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
logger = logging.get_logger(__name__)
def get_freqs(dim, max_period=10000.0):
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
return freqs
def fractal_flatten(x, rope, shape, block_mask=False):
if block_mask:
pixel_size = 8
x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
x = x.flatten(1, 2)
rope = rope.flatten(1, 2)
else:
x = x.flatten(1, 3)
rope = rope.flatten(1, 3)
return x, rope
def fractal_unflatten(x, shape, block_mask=False):
if block_mask:
pixel_size = 8
x = x.reshape(x.shape[0], -1, pixel_size**2, *x.shape[2:])
x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
else:
x = x.reshape(*shape, *x.shape[2:])
return x
def local_patching(x, shape, group_size, dim=0):
batch_size, duration, height, width = shape
g1, g2, g3 = group_size
x = x.reshape(
*x.shape[:dim],
duration // g1,
g1,
height // g2,
g2,
width // g3,
g3,
*x.shape[dim + 3 :],
)
x = x.permute(
*range(len(x.shape[:dim])),
dim,
dim + 2,
dim + 4,
dim + 1,
dim + 3,
dim + 5,
*range(dim + 6, len(x.shape)),
)
x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
return x
def local_merge(x, shape, group_size, dim=0):
batch_size, duration, height, width = shape
g1, g2, g3 = group_size
x = x.reshape(
*x.shape[:dim],
duration // g1,
height // g2,
width // g3,
g1,
g2,
g3,
*x.shape[dim + 2 :],
)
x = x.permute(
*range(len(x.shape[:dim])),
dim,
dim + 3,
dim + 1,
dim + 4,
dim + 2,
dim + 5,
*range(dim + 6, len(x.shape)),
)
x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
return x
def nablaT_v2(
q: Tensor,
k: Tensor,
sta: Tensor,
thr: float = 0.9,
):
if _CAN_USE_FLEX_ATTN:
from torch.nn.attention.flex_attention import BlockMask
else:
raise ValueError("Nabla attention is not supported with this version of PyTorch")
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
# Map estimation
B, h, S, D = q.shape
s1 = S // 64
qa = q.reshape(B, h, s1, 64, D).mean(-2)
ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
map = qa @ ka
map = torch.softmax(map / math.sqrt(D), dim=-1)
# Map binarization
vals, inds = map.sort(-1)
cvals = vals.cumsum_(-1)
mask = (cvals >= 1 - thr).int()
mask = mask.gather(-1, inds.argsort(-1))
mask = torch.logical_or(mask, sta)
# BlockMask creation
kv_nb = mask.sum(-1).to(torch.int32)
kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
return BlockMask.from_kv_blocks(torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None)
class Kandinsky5TimeEmbeddings(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0):
super().__init__()
assert model_dim % 2 == 0
self.model_dim = model_dim
self.max_period = max_period
self.freqs = get_freqs(self.model_dim // 2, self.max_period)
self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, time):
args = torch.outer(time, self.freqs.to(device=time.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
class Kandinsky5TextEmbeddings(nn.Module):
def __init__(self, text_dim, model_dim):
super().__init__()
self.in_layer = nn.Linear(text_dim, model_dim, bias=True)
self.norm = nn.LayerNorm(model_dim, elementwise_affine=True)
def forward(self, text_embed):
text_embed = self.in_layer(text_embed)
return self.norm(text_embed).type_as(text_embed)
class Kandinsky5VisualEmbeddings(nn.Module):
def __init__(self, visual_dim, model_dim, patch_size):
super().__init__()
self.patch_size = patch_size
self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)
def forward(self, x):
batch_size, duration, height, width, dim = x.shape
x = (
x.view(
batch_size,
duration // self.patch_size[0],
self.patch_size[0],
height // self.patch_size[1],
self.patch_size[1],
width // self.patch_size[2],
self.patch_size[2],
dim,
)
.permute(0, 1, 3, 5, 2, 4, 6, 7)
.flatten(4, 7)
)
return self.in_layer(x)
class Kandinsky5RoPE1D(nn.Module):
def __init__(self, dim, max_pos=1024, max_period=10000.0):
super().__init__()
self.max_period = max_period
self.dim = dim
self.max_pos = max_pos
freq = get_freqs(dim // 2, max_period)
pos = torch.arange(max_pos, dtype=freq.dtype)
self.register_buffer("args", torch.outer(pos, freq), persistent=False)
def forward(self, pos):
args = self.args[pos]
cosine = torch.cos(args)
sine = torch.sin(args)
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
rope = rope.view(*rope.shape[:-1], 2, 2)
return rope.unsqueeze(-4)
class Kandinsky5RoPE3D(nn.Module):
def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0):
super().__init__()
self.axes_dims = axes_dims
self.max_pos = max_pos
self.max_period = max_period
for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)):
freq = get_freqs(axes_dim // 2, max_period)
pos = torch.arange(ax_max_pos, dtype=freq.dtype)
self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False)
def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
batch_size, duration, height, width = shape
args_t = self.args_0[pos[0]] / scale_factor[0]
args_h = self.args_1[pos[1]] / scale_factor[1]
args_w = self.args_2[pos[2]] / scale_factor[2]
args = torch.cat(
[
args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1),
args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1),
args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1),
],
dim=-1,
)
cosine = torch.cos(args)
sine = torch.sin(args)
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
rope = rope.view(*rope.shape[:-1], 2, 2)
return rope.unsqueeze(-4)
class Kandinsky5Modulation(nn.Module):
def __init__(self, time_dim, model_dim, num_params):
super().__init__()
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, num_params * model_dim)
self.out_layer.weight.data.zero_()
self.out_layer.bias.data.zero_()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x):
return self.out_layer(self.activation(x))
class Kandinsky5AttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None):
# query, key, value = self.get_qkv(x)
query = attn.to_query(hidden_states)
if encoder_hidden_states is not None:
key = attn.to_key(encoder_hidden_states)
value = attn.to_value(encoder_hidden_states)
shape, cond_shape = query.shape[:-1], key.shape[:-1]
query = query.reshape(*shape, attn.num_heads, -1)
key = key.reshape(*cond_shape, attn.num_heads, -1)
value = value.reshape(*cond_shape, attn.num_heads, -1)
else:
key = attn.to_key(hidden_states)
value = attn.to_value(hidden_states)
shape = query.shape[:-1]
query = query.reshape(*shape, attn.num_heads, -1)
key = key.reshape(*shape, attn.num_heads, -1)
value = value.reshape(*shape, attn.num_heads, -1)
# query, key = self.norm_qk(query, key)
query = attn.query_norm(query.float()).type_as(query)
key = attn.key_norm(key.float()).type_as(key)
def apply_rotary(x, rope):
x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
x_out = (rope * x_).sum(dim=-1)
return x_out.reshape(*x.shape).to(torch.bfloat16)
if rotary_emb is not None:
query = apply_rotary(query, rotary_emb).type_as(query)
key = apply_rotary(key, rotary_emb).type_as(key)
if sparse_params is not None:
attn_mask = nablaT_v2(
query,
key,
sparse_params["sta_mask"],
thr=sparse_params["P"],
)
else:
attn_mask = None
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attn_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(-2, -1)
attn_out = attn.out_layer(hidden_states)
return attn_out
class Kandinsky5Attention(nn.Module, AttentionModuleMixin):
_default_processor_cls = Kandinsky5AttnProcessor
_available_processors = [
Kandinsky5AttnProcessor,
]
def __init__(self, num_channels, head_dim, processor=None):
super().__init__()
assert num_channels % head_dim == 0
self.num_heads = num_channels // head_dim
self.to_query = nn.Linear(num_channels, num_channels, bias=True)
self.to_key = nn.Linear(num_channels, num_channels, bias=True)
self.to_value = nn.Linear(num_channels, num_channels, bias=True)
self.query_norm = nn.RMSNorm(head_dim)
self.key_norm = nn.RMSNorm(head_dim)
self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
sparse_params: Optional[torch.Tensor] = None,
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {}
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"attention_processor_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
sparse_params=sparse_params,
rotary_emb=rotary_emb,
**kwargs,
)
class Kandinsky5FeedForward(nn.Module):
def __init__(self, dim, ff_dim):
super().__init__()
self.in_layer = nn.Linear(dim, ff_dim, bias=False)
self.activation = nn.GELU()
self.out_layer = nn.Linear(ff_dim, dim, bias=False)
def forward(self, x):
return self.out_layer(self.activation(self.in_layer(x)))
class Kandinsky5OutLayer(nn.Module):
def __init__(self, model_dim, time_dim, visual_dim, patch_size):
super().__init__()
self.patch_size = patch_size
self.modulation = Kandinsky5Modulation(time_dim, model_dim, 2)
self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.out_layer = nn.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True)
def forward(self, visual_embed, text_embed, time_embed):
shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
visual_embed = (
self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None]
).type_as(visual_embed)
x = self.out_layer(visual_embed)
batch_size, duration, height, width, _ = x.shape
x = (
x.view(
batch_size,
duration,
height,
width,
-1,
self.patch_size[0],
self.patch_size[1],
self.patch_size[2],
)
.permute(0, 1, 5, 2, 6, 3, 7, 4)
.flatten(1, 2)
.flatten(2, 3)
.flatten(3, 4)
)
return x
class Kandinsky5TransformerEncoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim):
super().__init__()
self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6)
self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
def forward(self, x, time_embed, rope):
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
out = self.self_attention(out, rotary_emb=rope)
x = (x.float() + gate.float() * out.float()).type_as(x)
shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
out = (self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
out = self.feed_forward(out)
x = (x.float() + gate.float() * out.float()).type_as(x)
return x
class Kandinsky5TransformerDecoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim):
super().__init__()
self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)
self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.cross_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
self_attn_params, cross_attn_params, ff_params = torch.chunk(
self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
)
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(
visual_embed
)
visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params)
visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(
visual_embed
)
visual_out = self.cross_attention(visual_out, encoder_hidden_states=text_embed)
visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(
visual_embed
)
visual_out = self.feed_forward(visual_out)
visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
return visual_embed
class Kandinsky5Transformer3DModel(
ModelMixin,
ConfigMixin,
PeftAdapterMixin,
FromOriginalModelMixin,
CacheMixin,
AttentionMixin,
):
"""
A 3D Diffusion Transformer model for video-like data.
"""
_repeated_blocks = [
"Kandinsky5TransformerEncoderBlock",
"Kandinsky5TransformerDecoderBlock",
]
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_visual_dim=4,
in_text_dim=3584,
in_text_dim2=768,
time_dim=512,
out_visual_dim=4,
patch_size=(1, 2, 2),
model_dim=2048,
ff_dim=5120,
num_text_blocks=2,
num_visual_blocks=32,
axes_dims=(16, 24, 24),
visual_cond=False,
attention_type: str = "regular",
attention_causal: bool = None,
attention_local: bool = None,
attention_glob: bool = None,
attention_window: int = None,
attention_P: float = None,
attention_wT: int = None,
attention_wW: int = None,
attention_wH: int = None,
attention_add_sta: bool = None,
attention_method: str = None,
):
super().__init__()
head_dim = sum(axes_dims)
self.in_visual_dim = in_visual_dim
self.model_dim = model_dim
self.patch_size = patch_size
self.visual_cond = visual_cond
self.attention_type = attention_type
visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
# Initialize embeddings
self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim)
self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim)
self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim)
self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size)
# Initialize positional embeddings
self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim)
self.visual_rope_embeddings = Kandinsky5RoPE3D(axes_dims)
# Initialize transformer blocks
self.text_transformer_blocks = nn.ModuleList(
[Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_text_blocks)]
)
self.visual_transformer_blocks = nn.ModuleList(
[
Kandinsky5TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
for _ in range(num_visual_blocks)
]
)
# Initialize output layer
self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor, # x
encoder_hidden_states: torch.Tensor, # text_embed
timestep: torch.Tensor, # time
pooled_projections: torch.Tensor, # pooled_text_embed
visual_rope_pos: Tuple[int, int, int],
text_rope_pos: torch.LongTensor,
scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0),
sparse_params: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Transformer2DModelOutput, torch.FloatTensor]:
"""
Forward pass of the Kandinsky5 3D Transformer.
Args:
hidden_states (`torch.FloatTensor`): Input visual states
encoder_hidden_states (`torch.FloatTensor`): Text embeddings
timestep (`torch.Tensor` or `float` or `int`): Current timestep
pooled_projections (`torch.FloatTensor`): Pooled text embeddings
visual_rope_pos (`Tuple[int, int, int]`): Position for visual RoPE
text_rope_pos (`torch.LongTensor`): Position for text RoPE
scale_factor (`Tuple[float, float, float]`, optional): Scale factor for RoPE
sparse_params (`Dict[str, Any]`, optional): Parameters for sparse attention
return_dict (`bool`, optional): Whether to return a dictionary
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: The output of the transformer
"""
x = hidden_states
text_embed = encoder_hidden_states
time = timestep
pooled_text_embed = pooled_projections
text_embed = self.text_embeddings(text_embed)
time_embed = self.time_embeddings(time)
time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
visual_embed = self.visual_embeddings(x)
text_rope = self.text_rope_embeddings(text_rope_pos)
text_rope = text_rope.unsqueeze(dim=0)
for text_transformer_block in self.text_transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
text_embed = self._gradient_checkpointing_func(
text_transformer_block, text_embed, time_embed, text_rope
)
else:
text_embed = text_transformer_block(text_embed, time_embed, text_rope)
visual_shape = visual_embed.shape[:-1]
visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, block_mask=to_fractal)
for visual_transformer_block in self.visual_transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
visual_embed = self._gradient_checkpointing_func(
visual_transformer_block,
visual_embed,
text_embed,
time_embed,
visual_rope,
sparse_params,
)
else:
visual_embed = visual_transformer_block(
visual_embed, text_embed, time_embed, visual_rope, sparse_params
)
visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal)
x = self.out_layer(visual_embed, text_embed, time_embed)
if not return_dict:
return x
return Transformer2DModelOutput(sample=x)

View File

@@ -382,6 +382,7 @@ else:
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
_import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
@@ -671,6 +672,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
)
from .kandinsky5 import Kandinsky5T2VPipeline
from .latent_consistency_models import (
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_kandinsky import Kandinsky5T2VPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,893 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
from typing import Callable, Dict, List, Optional, Union
import regex as re
import torch
from torch.nn import functional as F
from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import KandinskyLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo
from ...models.transformers import Kandinsky5Transformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import KandinskyPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_ftfy_available():
import ftfy
logger = logging.get_logger(__name__)
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import Kandinsky5T2VPipeline
>>> from diffusers.utils import export_to_video
>>> # Available models:
>>> # ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers
>>> # ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers
>>> # ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers
>>> # ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers
>>> model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
>>> pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> pipe = pipe.to("cuda")
>>> prompt = "A cat and a dog baking a cake together in a kitchen."
>>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
>>> output = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=512,
... width=768,
... num_frames=121,
... num_inference_steps=50,
... guidance_scale=5.0,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=24, quality=9)
```
"""
def basic_clean(text):
"""Clean text using ftfy if available and unescape HTML entities."""
if is_ftfy_available():
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
"""Normalize whitespace in text by replacing multiple spaces with single space."""
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
"""Apply both basic cleaning and whitespace normalization to prompts."""
text = whitespace_clean(basic_clean(text))
return text
class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using Kandinsky 5.0.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
transformer ([`Kandinsky5Transformer3DModel`]):
Conditional Transformer to denoise the encoded video latents.
vae ([`AutoencoderKLHunyuanVideo`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
Frozen text-encoder (Qwen2.5-VL).
tokenizer ([`AutoProcessor`]):
Tokenizer for Qwen2.5-VL.
text_encoder_2 ([`CLIPTextModel`]):
Frozen CLIP text encoder.
tokenizer_2 ([`CLIPTokenizer`]):
Tokenizer for CLIP.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds_qwen",
"prompt_embeds_clip",
"negative_prompt_embeds_qwen",
"negative_prompt_embeds_clip",
]
def __init__(
self,
transformer: Kandinsky5Transformer3DModel,
vae: AutoencoderKLHunyuanVideo,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2VLProcessor,
text_encoder_2: CLIPTextModel,
tokenizer_2: CLIPTokenizer,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
)
self.prompt_template = "\n".join(
[
"<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
"Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
"Describe the location of the video, main characters or objects and their action.",
"Describe the dynamism of the video and presented actions.",
"Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
"Describe the visual effects, postprocessing and transitions if they are presented in the video.",
"Pay attention to the order of key actions shown in the scene.<|im_end|>",
"<|im_start|>user\n{}<|im_end|>",
]
)
self.prompt_template_encode_start_idx = 129
self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@staticmethod
def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> torch.Tensor:
"""
Create a sparse temporal attention (STA) mask for efficient video generation.
This method generates a mask that limits attention to nearby frames and spatial positions, reducing
computational complexity for video generation.
Args:
T (int): Number of temporal frames
H (int): Height in latent space
W (int): Width in latent space
wT (int): Temporal attention window size
wH (int): Height attention window size
wW (int): Width attention window size
device (str): Device to create tensor on
Returns:
torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W)
"""
l = torch.Tensor([T, H, W]).amax()
r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
sta_t, sta_h, sta_w = (
mat[:T, :T].flatten(),
mat[:H, :H].flatten(),
mat[:W, :W].flatten(),
)
sta_t = sta_t <= wT // 2
sta_h = sta_h <= wH // 2
sta_w = sta_w <= wW // 2
sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten()
sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2)
return sta.reshape(T * H * W, T * H * W)
def get_sparse_params(self, sample, device):
"""
Generate sparse attention parameters for the transformer based on sample dimensions.
This method computes the sparse attention configuration needed for efficient video processing in the
transformer model.
Args:
sample (torch.Tensor): Input sample tensor
device (torch.device): Device to place tensors on
Returns:
Dict: Dictionary containing sparse attention parameters
"""
assert self.transformer.config.patch_size[0] == 1
B, T, H, W, _ = sample.shape
T, H, W = (
T // self.transformer.config.patch_size[0],
H // self.transformer.config.patch_size[1],
W // self.transformer.config.patch_size[2],
)
if self.transformer.config.attention_type == "nabla":
sta_mask = self.fast_sta_nabla(
T,
H // 8,
W // 8,
self.transformer.config.attention_wT,
self.transformer.config.attention_wH,
self.transformer.config.attention_wW,
device=device,
)
sparse_params = {
"sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
"attention_type": self.transformer.config.attention_type,
"to_fractal": True,
"P": self.transformer.config.attention_P,
"wT": self.transformer.config.attention_wT,
"wW": self.transformer.config.attention_wW,
"wH": self.transformer.config.attention_wH,
"add_sta": self.transformer.config.attention_add_sta,
"visual_shape": (T, H, W),
"method": self.transformer.config.attention_method,
}
else:
sparse_params = None
return sparse_params
def _encode_prompt_qwen(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
max_sequence_length: int = 256,
dtype: Optional[torch.dtype] = None,
):
"""
Encode prompt using Qwen2.5-VL text encoder.
This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for
video generation.
Args:
prompt (Union[str, List[str]]): Input prompt or list of prompts
device (torch.device): Device to run encoding on
num_videos_per_prompt (int): Number of videos to generate per prompt
max_sequence_length (int): Maximum sequence length for tokenization
dtype (torch.dtype): Data type for embeddings
Returns:
Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
"""
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
full_texts = [self.prompt_template.format(p) for p in prompt]
inputs = self.tokenizer(
text=full_texts,
images=None,
videos=None,
max_length=max_sequence_length + self.prompt_template_encode_start_idx,
truncation=True,
return_tensors="pt",
padding=True,
).to(device)
embeds = self.text_encoder(
input_ids=inputs["input_ids"],
return_dict=True,
output_hidden_states=True,
)["hidden_states"][-1][:, self.prompt_template_encode_start_idx :]
attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :]
cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
return embeds.to(dtype), cu_seqlens
def _encode_prompt_clip(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
"""
Encode prompt using CLIP text encoder.
This method processes the input prompt through the CLIP model to generate pooled embeddings that capture
semantic information.
Args:
prompt (Union[str, List[str]]): Input prompt or list of prompts
device (torch.device): Device to run encoding on
num_videos_per_prompt (int): Number of videos to generate per prompt
dtype (torch.dtype): Data type for embeddings
Returns:
torch.Tensor: Pooled text embeddings from CLIP
"""
device = device or self._execution_device
dtype = dtype or self.text_encoder_2.dtype
inputs = self.tokenizer_2(
prompt,
max_length=77,
truncation=True,
add_special_tokens=True,
padding="max_length",
return_tensors="pt",
).to(device)
pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
return pooled_embed.to(dtype)
def encode_prompt(
self,
prompt: Union[str, List[str]],
num_videos_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes a single prompt (positive or negative) into text encoder hidden states.
This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text
representations for video generation.
Args:
prompt (`str` or `List[str]`):
Prompt to be encoded.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos to generate per prompt.
max_sequence_length (`int`, *optional*, defaults to 512):
Maximum sequence length for text encoding.
device (`torch.device`, *optional*):
Torch device.
dtype (`torch.dtype`, *optional*):
Torch dtype.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim)
- CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim)
- Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size *
num_videos_per_prompt + 1,)
"""
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
batch_size = len(prompt)
prompt = [prompt_clean(p) for p in prompt]
# Encode with Qwen2.5-VL
prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
prompt=prompt,
device=device,
max_sequence_length=max_sequence_length,
dtype=dtype,
)
# prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim]
# Encode with CLIP
prompt_embeds_clip = self._encode_prompt_clip(
prompt=prompt,
device=device,
dtype=dtype,
)
# prompt_embeds_clip shape: [batch_size, clip_embed_dim]
# Repeat embeddings for num_videos_per_prompt
# Qwen embeddings: repeat sequence for each video, then reshape
prompt_embeds_qwen = prompt_embeds_qwen.repeat(
1, num_videos_per_prompt, 1
) # [batch_size, seq_len * num_videos_per_prompt, embed_dim]
# Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim]
prompt_embeds_qwen = prompt_embeds_qwen.view(
batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1]
)
# CLIP embeddings: repeat for each video
prompt_embeds_clip = prompt_embeds_clip.repeat(
1, num_videos_per_prompt, 1
) # [batch_size, num_videos_per_prompt, clip_embed_dim]
# Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim]
prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1)
# Repeat cumulative sequence lengths for num_videos_per_prompt
# Original cu_seqlens: [0, len1, len1+len2, ...]
# Need to repeat the differences and reconstruct for repeated prompts
# Original differences (lengths) for each prompt in the batch
original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...]
# Repeat the lengths for num_videos_per_prompt
repeated_lengths = original_lengths.repeat_interleave(
num_videos_per_prompt
) # [len1, len1, ..., len2, len2, ...]
# Reconstruct the cumulative lengths
repeated_cu_seqlens = torch.cat(
[torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)]
)
return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens
def check_inputs(
self,
prompt,
negative_prompt,
height,
width,
prompt_embeds_qwen=None,
prompt_embeds_clip=None,
negative_prompt_embeds_qwen=None,
negative_prompt_embeds_clip=None,
prompt_cu_seqlens=None,
negative_prompt_cu_seqlens=None,
callback_on_step_end_tensor_inputs=None,
):
"""
Validate input parameters for the pipeline.
Args:
prompt: Input prompt
negative_prompt: Negative prompt for guidance
height: Video height
width: Video width
prompt_embeds_qwen: Pre-computed Qwen prompt embeddings
prompt_embeds_clip: Pre-computed CLIP prompt embeddings
negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings
negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings
prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt
negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt
callback_on_step_end_tensor_inputs: Callback tensor inputs
Raises:
ValueError: If inputs are invalid
"""
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
# Check for consistency within positive prompt embeddings and sequence lengths
if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None:
if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None:
raise ValueError(
"If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, "
"all three must be provided."
)
# Check for consistency within negative prompt embeddings and sequence lengths
if (
negative_prompt_embeds_qwen is not None
or negative_prompt_embeds_clip is not None
or negative_prompt_cu_seqlens is not None
):
if (
negative_prompt_embeds_qwen is None
or negative_prompt_embeds_clip is None
or negative_prompt_cu_seqlens is None
):
raise ValueError(
"If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, "
"all three must be provided."
)
# Check if prompt or embeddings are provided (either prompt or all required embedding components for positive)
if prompt is None and prompt_embeds_qwen is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined."
)
# Validate types for prompt and negative_prompt if provided
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 16,
height: int = 480,
width: int = 832,
num_frames: int = 81,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Prepare initial latent variables for video generation.
This method creates random noise latents or uses provided latents as starting point for the denoising process.
Args:
batch_size (int): Number of videos to generate
num_channels_latents (int): Number of channels in latent space
height (int): Height of generated video
width (int): Width of generated video
num_frames (int): Number of frames in video
dtype (torch.dtype): Data type for latents
device (torch.device): Device to create latents on
generator (torch.Generator): Random number generator
latents (torch.Tensor): Pre-existing latents to use
Returns:
torch.Tensor: Prepared latent tensor
"""
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_latent_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
num_channels_latents,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if self.transformer.visual_cond:
# For visual conditioning, concatenate with zeros and mask
visual_cond = torch.zeros_like(latents)
visual_cond_mask = torch.zeros(
[
batch_size,
num_latent_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
1,
],
dtype=latents.dtype,
device=latents.device,
)
latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1)
return latents
@property
def guidance_scale(self):
"""Get the current guidance scale value."""
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
"""Check if classifier-free guidance is enabled."""
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
"""Get the number of denoising timesteps."""
return self._num_timesteps
@property
def interrupt(self):
"""Check if generation has been interrupted."""
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 768,
num_frames: int = 121,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds_qwen: Optional[torch.Tensor] = None,
prompt_embeds_clip: Optional[torch.Tensor] = None,
negative_prompt_embeds_qwen: Optional[torch.Tensor] = None,
negative_prompt_embeds_clip: Optional[torch.Tensor] = None,
prompt_cu_seqlens: Optional[torch.Tensor] = None,
negative_prompt_cu_seqlens: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
**kwargs,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
instead. Ignored when not using guidance (`guidance_scale` < `1`).
height (`int`, defaults to `512`):
The height in pixels of the generated video.
width (`int`, defaults to `768`):
The width in pixels of the generated video.
num_frames (`int`, defaults to `25`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps.
guidance_scale (`float`, defaults to `5.0`):
Guidance scale as defined in classifier-free guidance.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A torch generator to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`KandinskyPipelineOutput`].
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function that is called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function.
max_sequence_length (`int`, defaults to `512`):
The maximum sequence length for text encoding.
Examples:
Returns:
[`~KandinskyPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned
where the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
prompt_embeds_qwen=prompt_embeds_qwen,
prompt_embeds_clip=prompt_embeds_clip,
negative_prompt_embeds_qwen=negative_prompt_embeds_qwen,
negative_prompt_embeds_clip=negative_prompt_embeds_clip,
prompt_cu_seqlens=prompt_cu_seqlens,
negative_prompt_cu_seqlens=negative_prompt_cu_seqlens,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
if num_frames % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
)
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
self._guidance_scale = guidance_scale
self._interrupt = False
device = self._execution_device
dtype = self.transformer.dtype
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
prompt = [prompt]
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds_qwen.shape[0]
# 3. Encode input prompt
if prompt_embeds_qwen is None:
prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
prompt=prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if self.do_classifier_free_guidance:
if negative_prompt is None:
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
elif len(negative_prompt) != len(prompt):
raise ValueError(
f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
)
if negative_prompt_embeds_qwen is None:
negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt(
prompt=negative_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_visual_dim
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
dtype,
device,
generator,
latents,
)
# 6. Prepare rope positions for positional encoding
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
visual_rope_pos = [
torch.arange(num_latent_frames, device=device),
torch.arange(height // self.vae_scale_factor_spatial // 2, device=device),
torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
]
text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
negative_text_rope_pos = (
torch.arange(negative_cu_seqlens.diff().max().item(), device=device)
if negative_cu_seqlens is not None
else None
)
# 7. Sparse Params for efficient attention
sparse_params = self.get_sparse_params(latents, device)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)
# Predict noise residual
pred_velocity = self.transformer(
hidden_states=latents.to(dtype),
encoder_hidden_states=prompt_embeds_qwen.to(dtype),
pooled_projections=prompt_embeds_clip.to(dtype),
timestep=timestep.to(dtype),
visual_rope_pos=visual_rope_pos,
text_rope_pos=text_rope_pos,
scale_factor=(1, 2, 2),
sparse_params=sparse_params,
return_dict=True,
).sample
if self.do_classifier_free_guidance and negative_prompt_embeds_qwen is not None:
uncond_pred_velocity = self.transformer(
hidden_states=latents.to(dtype),
encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype),
pooled_projections=negative_prompt_embeds_clip.to(dtype),
timestep=timestep.to(dtype),
visual_rope_pos=visual_rope_pos,
text_rope_pos=negative_text_rope_pos,
scale_factor=(1, 2, 2),
sparse_params=sparse_params,
return_dict=True,
).sample
pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity)
# Compute previous sample using the scheduler
latents[:, :, :, :, :num_channels_latents] = self.scheduler.step(
pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False
)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen)
prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip)
negative_prompt_embeds_qwen = callback_outputs.pop(
"negative_prompt_embeds_qwen", negative_prompt_embeds_qwen
)
negative_prompt_embeds_clip = callback_outputs.pop(
"negative_prompt_embeds_clip", negative_prompt_embeds_clip
)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# 8. Post-processing - extract main latents
latents = latents[:, :, :, :, :num_channels_latents]
# 9. Decode latents to video
if output_type != "latent":
latents = latents.to(self.vae.dtype)
# Reshape and normalize latents
video = latents.reshape(
batch_size,
num_videos_per_prompt,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
num_channels_latents,
)
video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width]
video = video.reshape(
batch_size * num_videos_per_prompt,
num_channels_latents,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
# Normalize and decode through VAE
video = video / self.vae.config.scaling_factor
video = self.vae.decode(video).sample
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return KandinskyPipelineOutput(frames=video)

View File

@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class KandinskyPipelineOutput(BaseOutput):
r"""
Output class for Wan pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor

View File

@@ -918,6 +918,21 @@ class Kandinsky3UNet(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class Kandinsky5Transformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LatteTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1247,6 +1247,21 @@ class Kandinsky3Pipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Kandinsky5T2VPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class KandinskyCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]