mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add transformer pipeline first version

This commit is contained in:
leffff
2025-10-04 10:10:23 +00:00
parent 7242b5ff62
commit d53f848720
10 changed files with 1541 additions and 1 deletion


@@ -260,6 +260,7 @@ else:
"VQModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"Kandinsky5Transformer3DModel",
"attention_backend",
]
)
@@ -618,6 +619,7 @@ else:
"WanPipeline",
"WanVACEPipeline",
"WanVideoToVideoPipeline",
"Kandinsky5T2VPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
@@ -947,6 +949,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
Kandinsky5Transformer3DModel,
attention_backend,
)
from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
@@ -1275,6 +1278,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WanPipeline,
WanVACEPipeline,
WanVideoToVideoPipeline,
Kandinsky5T2VPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,


@@ -77,6 +77,7 @@ if is_torch_available():
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"KandinskyLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
@@ -126,6 +127,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
KandinskyLoraLoaderMixin,
)
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin


@@ -3638,6 +3638,292 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
"""
super().unfuse_lora(components=components, **kwargs)
class KandinskyLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`Kandinsky5Transformer3DModel`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* of a pretrained model hosted on the Hub.
- A path to a *directory* containing the model weights.
- A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
use_safetensors (`bool`, *optional*):
Whether to use safetensors for loading.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata.
"""
# Load the main state dict first which has the LoRA layers
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
hotswap (`bool`, *optional*):
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
kwargs (`dict`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
# Load LoRA into transformer
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
Load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters.
transformer (`Kandinsky5Transformer3DModel`):
The transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights.
hotswap (`bool`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the transformer and text encoders.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process.
save_function (`Callable`):
The function to use to save the state dictionary.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not lora_layers:
raise ValueError(
"You must pass at least one of `transformer_lora_layers`"
)
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing.
Example:
```py
from diffusers import Kandinsky5T2VPipeline
pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
pipeline.load_lora_weights("path/to/lora.safetensors")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of [`pipe.fuse_lora()`].
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
super().unfuse_lora(components=components, **kwargs)
class WanLoraLoaderMixin(LoraBaseMixin):
r"""
@@ -4802,4 +5088,4 @@ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
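
The mixin above exposes the standard diffusers LoRA entry points. A minimal usage sketch, assuming the checkpoint id used in the docstrings above; the LoRA file path and adapter name are illustrative:

```python
import torch

from diffusers import Kandinsky5T2VPipeline

# Checkpoint id taken from the docstrings above; the LoRA file is illustrative.
pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Load a LoRA adapter into the transformer and fuse it for inference.
pipe.load_lora_weights("path/to/kandinsky_lora.safetensors", adapter_name="style")
pipe.fuse_lora(lora_scale=0.8)

# ... generate videos, then undo the fusion if the adapter is no longer needed.
pipe.unfuse_lora()
```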


@@ -101,6 +101,7 @@ if is_torch_available():
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -200,6 +201,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
TransformerTemporalModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
Kandinsky5Transformer3DModel,
)
from .unets import (
I2VGenXLUNet,


@@ -37,3 +37,4 @@ if is_torch_available():
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel


@@ -0,0 +1,630 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple, Union, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__)
# @torch.compile()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)
# @torch.compile()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_gate_sum(x, out, gate):
return (x + gate * out).to(torch.bfloat16)
# @torch.compile()
@torch.autocast(device_type="cuda", enabled=False)
def apply_rotary(x, rope):
x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
x_out = (rope * x_).sum(dim=-1)
return x_out.reshape(*x.shape).to(torch.bfloat16)
@torch.autocast(device_type="cuda", enabled=False)
def get_freqs(dim, max_period=10000.0):
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=dim, dtype=torch.float32)
/ dim
)
return freqs
class TimeEmbeddings(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0):
super().__init__()
assert model_dim % 2 == 0
self.model_dim = model_dim
self.max_period = max_period
self.register_buffer(
"freqs", get_freqs(model_dim // 2, max_period), persistent=False
)
self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
def forward(self, time):
args = torch.outer(time, self.freqs.to(device=time.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
class TextEmbeddings(nn.Module):
def __init__(self, text_dim, model_dim):
super().__init__()
self.in_layer = nn.Linear(text_dim, model_dim, bias=True)
self.norm = nn.LayerNorm(model_dim, elementwise_affine=True)
def forward(self, text_embed):
text_embed = self.in_layer(text_embed)
return self.norm(text_embed).type_as(text_embed)
class VisualEmbeddings(nn.Module):
def __init__(self, visual_dim, model_dim, patch_size):
super().__init__()
self.patch_size = patch_size
self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)
def forward(self, x):
batch_size, duration, height, width, dim = x.shape
x = (
x.view(
batch_size,
duration // self.patch_size[0],
self.patch_size[0],
height // self.patch_size[1],
self.patch_size[1],
width // self.patch_size[2],
self.patch_size[2],
dim,
)
.permute(0, 1, 3, 5, 2, 4, 6, 7)
.flatten(4, 7)
)
return self.in_layer(x)
class RoPE1D(nn.Module):
"""
1D Rotary Positional Embeddings for text sequences.
Args:
dim: Dimension of the rotary embeddings
max_pos: Maximum sequence length
max_period: Maximum period for sinusoidal embeddings
"""
def __init__(self, dim, max_pos=1024, max_period=10000.0):
super().__init__()
self.max_period = max_period
self.dim = dim
self.max_pos = max_pos
freq = get_freqs(dim // 2, max_period)
pos = torch.arange(max_pos, dtype=freq.dtype)
self.register_buffer("args", torch.outer(pos, freq), persistent=False)
def forward(self, pos):
"""
Args:
pos: Position indices of shape [seq_len] or [batch_size, seq_len]
Returns:
Rotary embeddings of shape [seq_len, 1, dim // 2, 2, 2]
"""
args = self.args[pos]
cosine = torch.cos(args)
sine = torch.sin(args)
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
rope = rope.view(*rope.shape[:-1], 2, 2)
return rope.unsqueeze(-4)
class RoPE3D(nn.Module):
def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0):
super().__init__()
self.axes_dims = axes_dims
self.max_pos = max_pos
self.max_period = max_period
for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)):
freq = get_freqs(axes_dim // 2, max_period)
pos = torch.arange(ax_max_pos, dtype=freq.dtype)
self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False)
@torch.autocast(device_type="cuda", enabled=False)
def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
batch_size, duration, height, width = shape
args_t = self.args_0[pos[0]] / scale_factor[0]
args_h = self.args_1[pos[1]] / scale_factor[1]
args_w = self.args_2[pos[2]] / scale_factor[2]
# Replicate the original logic with batch dimension
args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1)
args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1)
args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1)
# Concatenate along the last dimension
args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) # [B, D, H, W, F]
cosine = torch.cos(args)
sine = torch.sin(args)
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) # [B, D, H, W, F, 4]
rope = rope.view(*rope.shape[:-1], 2, 2) # [B, D, H, W, F, 2, 2]
return rope.unsqueeze(-4) # [B, D, H, W, 1, F, 2, 2]
class Modulation(nn.Module):
def __init__(self, time_dim, model_dim, num_params):
super().__init__()
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, num_params * model_dim)
self.out_layer.weight.data.zero_()
self.out_layer.bias.data.zero_()
def forward(self, x):
return self.out_layer(self.activation(x))
class MultiheadSelfAttentionEnc(nn.Module):
def __init__(self, num_channels, head_dim):
super().__init__()
assert num_channels % head_dim == 0
self.num_heads = num_channels // head_dim
self.to_query = nn.Linear(num_channels, num_channels, bias=True)
self.to_key = nn.Linear(num_channels, num_channels, bias=True)
self.to_value = nn.Linear(num_channels, num_channels, bias=True)
self.query_norm = nn.RMSNorm(head_dim)
self.key_norm = nn.RMSNorm(head_dim)
self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
def forward(self, x, rope):
query = self.to_query(x)
key = self.to_key(x)
value = self.to_value(x)
shape = query.shape[:-1]
query = query.reshape(*shape, self.num_heads, -1)
key = key.reshape(*shape, self.num_heads, -1)
value = value.reshape(*shape, self.num_heads, -1)
query = self.query_norm(query.float()).type_as(query)
key = self.key_norm(key.float()).type_as(key)
query = apply_rotary(query, rope).type_as(query)
key = apply_rotary(key, rope).type_as(key)
# Use torch's scaled_dot_product_attention; SDPA expects (batch, heads, seq, head_dim),
# so move the head dimension in front of the sequence dimension and back afterwards.
out = F.scaled_dot_product_attention(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
).transpose(1, 2).flatten(-2, -1)
out = self.out_layer(out)
return out
class MultiheadSelfAttentionDec(nn.Module):
def __init__(self, num_channels, head_dim):
super().__init__()
assert num_channels % head_dim == 0
self.num_heads = num_channels // head_dim
self.to_query = nn.Linear(num_channels, num_channels, bias=True)
self.to_key = nn.Linear(num_channels, num_channels, bias=True)
self.to_value = nn.Linear(num_channels, num_channels, bias=True)
self.query_norm = nn.RMSNorm(head_dim)
self.key_norm = nn.RMSNorm(head_dim)
self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
def forward(self, x, rope, sparse_params=None):
query = self.to_query(x)
key = self.to_key(x)
value = self.to_value(x)
shape = query.shape[:-1]
query = query.reshape(*shape, self.num_heads, -1)
key = key.reshape(*shape, self.num_heads, -1)
value = value.reshape(*shape, self.num_heads, -1)
query = self.query_norm(query.float()).type_as(query)
key = self.key_norm(key.float()).type_as(key)
query = apply_rotary(query, rope).type_as(query)
key = apply_rotary(key, rope).type_as(key)
# Standard attention (can be extended with sparse attention); transpose so SDPA
# attends over the sequence dimension rather than the head dimension.
out = F.scaled_dot_product_attention(
query.transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
).transpose(1, 2).flatten(-2, -1)
out = self.out_layer(out)
return out
class MultiheadCrossAttention(nn.Module):
def __init__(self, num_channels, head_dim):
super().__init__()
assert num_channels % head_dim == 0
self.num_heads = num_channels // head_dim
self.to_query = nn.Linear(num_channels, num_channels, bias=True)
self.to_key = nn.Linear(num_channels, num_channels, bias=True)
self.to_value = nn.Linear(num_channels, num_channels, bias=True)
self.query_norm = nn.RMSNorm(head_dim)
self.key_norm = nn.RMSNorm(head_dim)
self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
def forward(self, x, cond):
query = self.to_query(x)
key = self.to_key(cond)
value = self.to_value(cond)
shape, cond_shape = query.shape[:-1], key.shape[:-1]
query = query.reshape(*shape, self.num_heads, -1)
key = key.reshape(*cond_shape, self.num_heads, -1)
value = value.reshape(*cond_shape, self.num_heads, -1)
query = self.query_norm(query.float()).type_as(query)
key = self.key_norm(key.float()).type_as(key)
out = F.scaled_dot_product_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
).permute(0, 2, 1, 3).flatten(-2, -1)
out = self.out_layer(out)
return out
class FeedForward(nn.Module):
def __init__(self, dim, ff_dim):
super().__init__()
self.in_layer = nn.Linear(dim, ff_dim, bias=False)
self.activation = nn.GELU()
self.out_layer = nn.Linear(ff_dim, dim, bias=False)
def forward(self, x):
return self.out_layer(self.activation(self.in_layer(x)))
class TransformerEncoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim):
super().__init__()
self.text_modulation = Modulation(time_dim, model_dim, 6)
self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.self_attention = MultiheadSelfAttentionEnc(model_dim, head_dim)
self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.feed_forward = FeedForward(model_dim, ff_dim)
def forward(self, x, time_embed, rope):
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
out = self.self_attention_norm(x)
out = out * (scale + 1.0) + shift
out = self.self_attention(out, rope)
x = x + gate * out
shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
out = self.feed_forward_norm(x)
out = out * (scale + 1.0) + shift
out = self.feed_forward(out)
x = x + gate * out
return x
class TransformerDecoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim):
super().__init__()
self.visual_modulation = Modulation(time_dim, model_dim, 9)
self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.self_attention = MultiheadSelfAttentionDec(model_dim, head_dim)
self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.cross_attention = MultiheadCrossAttention(model_dim, head_dim)
self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.feed_forward = FeedForward(model_dim, ff_dim)
def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
self_attn_params, cross_attn_params, ff_params = torch.chunk(
self.visual_modulation(time_embed), 3, dim=-1
)
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
visual_out = self.self_attention_norm(visual_embed)
visual_out = visual_out * (scale + 1.0) + shift
visual_out = self.self_attention(visual_out, rope, sparse_params)
visual_embed = visual_embed + gate * visual_out
shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
visual_out = self.cross_attention_norm(visual_embed)
visual_out = visual_out * (scale + 1.0) + shift
visual_out = self.cross_attention(visual_out, text_embed)
visual_embed = visual_embed + gate * visual_out
shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
visual_out = self.feed_forward_norm(visual_embed)
visual_out = visual_out * (scale + 1.0) + shift
visual_out = self.feed_forward(visual_out)
visual_embed = visual_embed + gate * visual_out
return visual_embed
class OutLayer(nn.Module):
def __init__(self, model_dim, time_dim, visual_dim, patch_size):
super().__init__()
self.patch_size = patch_size
self.modulation = Modulation(time_dim, model_dim, 2)
self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.out_layer = nn.Linear(
model_dim, math.prod(patch_size) * visual_dim, bias=True
)
def forward(self, visual_embed, text_embed, time_embed):
# Handle the new batch dimension: [batch, duration, height, width, model_dim]
batch_size, duration, height, width, _ = visual_embed.shape
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
# Apply modulation with proper broadcasting for the new shape
visual_embed = apply_scale_shift_norm(
self.norm,
visual_embed,
scale[:, None, None, None], # [batch, model_dim] -> [batch, 1, 1, 1, model_dim]
shift[:, None, None, None], # [batch, model_dim] -> [batch, 1, 1, 1, model_dim]
).type_as(visual_embed)
x = self.out_layer(visual_embed)
# Now x has shape [batch, duration, height, width, patch_prod * visual_dim]
x = (
x.view(
batch_size,
duration,
height,
width,
-1,
self.patch_size[0],
self.patch_size[1],
self.patch_size[2],
)
.permute(0, 5, 1, 6, 2, 7, 3, 4) # [batch, patch_t, duration, patch_h, height, patch_w, width, features]
.flatten(1, 2) # [batch, patch_t * duration, patch_h, height, patch_w, width, features]
.flatten(2, 3) # [batch, patch_t * duration, patch_h * height, patch_w, width, features]
.flatten(3, 4) # [batch, patch_t * duration, patch_h * height, patch_w * width, features]
)
return x
@maybe_allow_in_graph
class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
A 3D Transformer model for video generation used in Kandinsky 5.0.
This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods implemented for all models (such as downloading or saving).
Args:
in_visual_dim (`int`, defaults to 16):
Number of channels in the input visual latent.
out_visual_dim (`int`, defaults to 16):
Number of channels in the output visual latent.
time_dim (`int`, defaults to 512):
Dimension of the time embeddings.
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
Patch size for the visual embeddings (temporal, height, width).
model_dim (`int`, defaults to 1792):
Hidden dimension of the transformer model.
ff_dim (`int`, defaults to 7168):
Intermediate dimension of the feed-forward networks.
num_text_blocks (`int`, defaults to 2):
Number of transformer blocks in the text encoder.
num_visual_blocks (`int`, defaults to 32):
Number of transformer blocks in the visual decoder.
axes_dims (`Tuple[int]`, defaults to `(16, 24, 24)`):
Dimensions for the rotary positional embeddings (temporal, height, width).
visual_cond (`bool`, defaults to `True`):
Whether to use visual conditioning (for image/video conditioning).
in_text_dim (`int`, defaults to 3584):
Dimension of the text embeddings from Qwen2.5-VL.
in_text_dim2 (`int`, defaults to 768):
Dimension of the pooled text embeddings from CLIP.
"""
@register_to_config
def __init__(
self,
in_visual_dim: int = 16,
out_visual_dim: int = 16,
time_dim: int = 512,
patch_size: Tuple[int, int, int] = (1, 2, 2),
model_dim: int = 1792,
ff_dim: int = 7168,
num_text_blocks: int = 2,
num_visual_blocks: int = 32,
axes_dims: Tuple[int, int, int] = (16, 24, 24),
visual_cond: bool = True,
in_text_dim: int = 3584,
in_text_dim2: int = 768,
):
super().__init__()
self.in_visual_dim = in_visual_dim
self.model_dim = model_dim
self.patch_size = patch_size
self.visual_cond = visual_cond
# Calculate head dimension for attention
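# (equals the sum of the per-axis RoPE dims, so the 3D rotary embedding over time/height/width spans the full head)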
head_dim = sum(axes_dims)
# Determine visual embedding dimension based on conditioning
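# (with conditioning, the channel axis packs [noisy latents | conditioning latents | 1-channel mask],
# hence 2 * in_visual_dim + 1; see `prepare_latents` in the pipeline)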
visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
# 1. Embedding layers
self.time_embeddings = TimeEmbeddings(model_dim, time_dim)
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim)
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim)
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size)
# 2. Rotary positional embeddings
self.text_rope_embeddings = RoPE1D(head_dim)
self.visual_rope_embeddings = RoPE3D(axes_dims)
# 3. Transformer blocks
self.text_transformer_blocks = nn.ModuleList([
TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim)
for _ in range(num_text_blocks)
])
self.visual_transformer_blocks = nn.ModuleList([
TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
for _ in range(num_visual_blocks)
])
# 4. Output layer
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
pooled_text_embed: torch.Tensor,
timestep: torch.Tensor,
visual_rope_pos: List[torch.Tensor],
text_rope_pos: torch.Tensor,
scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0),
sparse_params: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
Forward pass of the Kandinsky 5.0 3D Transformer.
Args:
hidden_states (`torch.Tensor`):
Input visual latent tensor of shape `(batch_size, num_frames, height, width, channels)`.
encoder_hidden_states (`torch.Tensor`):
Text embeddings from Qwen2.5-VL of shape `(batch_size, sequence_length, text_dim)`.
pooled_text_embed (`torch.Tensor`):
Pooled text embeddings from CLIP of shape `(batch_size, pooled_text_dim)`.
timestep (`torch.Tensor`):
Timestep tensor of shape `(batch_size,)` or `(batch_size * num_frames,)`.
visual_rope_pos (`List[torch.Tensor]`):
List of tensors for visual rotary positional embeddings [temporal, height, width].
text_rope_pos (`torch.Tensor`):
Tensor for text rotary positional embeddings.
scale_factor (`Tuple[float, float, float]`, defaults to `(1.0, 1.0, 1.0)`):
Scale factors for rotary positional embeddings.
sparse_params (`Dict[str, Any]`, *optional*):
Parameters for sparse attention.
return_dict (`bool`, defaults to `True`):
Whether to return a dictionary or a tensor.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned,
otherwise a `tuple` where the first element is the sample tensor.
"""
batch_size, num_frames, height, width, channels = hidden_states.shape
# 1. Process text embeddings
text_embed = self.text_embeddings(encoder_hidden_states)
time_embed = self.time_embeddings(timestep)
# Add pooled text embedding to time embedding
pooled_embed = self.pooled_text_embeddings(pooled_text_embed)
time_embed = time_embed + pooled_embed
# 2. Process visual embeddings: [batch_size, num_frames // p_t, height // p_h, width // p_w, model_dim]
visual_embed = self.visual_embeddings(hidden_states)
# 3. Text rotary embeddings
text_rope = self.text_rope_embeddings(text_rope_pos)
# 4. Text transformer blocks
for text_block in self.text_transformer_blocks:
if self.gradient_checkpointing and self.training:
text_embed = torch.utils.checkpoint.checkpoint(
text_block, text_embed, time_embed, text_rope, use_reentrant=False
)
else:
text_embed = text_block(text_embed, time_embed, text_rope)
# 5. Prepare visual rope
visual_shape = visual_embed.shape[:-1]
visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1])
visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:]))
# 6. Visual transformer blocks
for visual_block in self.visual_transformer_blocks:
if self.gradient_checkpointing and self.training:
visual_embed = torch.utils.checkpoint.checkpoint(
visual_block,
visual_embed,
text_embed,
time_embed,
visual_rope,
sparse_params,
use_reentrant=False,
)
else:
visual_embed = visual_block(
visual_embed, text_embed, time_embed, visual_rope, sparse_params
)
# 7. Output projection
visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1)
output = self.out_layer(visual_embed, text_embed, time_embed)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
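
The patchify performed by `VisualEmbeddings` (and inverted by `OutLayer`) is easiest to follow in isolation. A standalone sketch of the reshape/permute bookkeeping; all sizes are illustrative and the tensor is random:

```python
import torch

# Illustrative sizes: 1 video, 4 latent frames, an 8x8 latent grid, 16 channels,
# and patch_size = (1, 2, 2) as in the default Kandinsky5Transformer3DModel config.
b, t, h, w, c = 1, 4, 8, 8, 16
pt, ph, pw = 1, 2, 2
x = torch.randn(b, t, h, w, c)

# Same reshape/permute as VisualEmbeddings.forward: each (pt, ph, pw) patch of latent
# pixels becomes one token with pt * ph * pw * c features.
tokens = (
    x.view(b, t // pt, pt, h // ph, ph, w // pw, pw, c)
    .permute(0, 1, 3, 5, 2, 4, 6, 7)
    .flatten(4, 7)
)
print(tokens.shape)  # torch.Size([1, 4, 4, 4, 64]) == (b, t/pt, h/ph, w/pw, pt*ph*pw*c)
```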


@@ -382,6 +382,7 @@ else:
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
_import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
@@ -787,6 +788,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
from .kandinsky5 import Kandinsky5T2VPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,


@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_kandinsky import Kandinsky5T2VPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)


@@ -0,0 +1,545 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
from typing import Any, Callable, Dict, List, Optional, Union
import regex as re
import torch
from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer
import torchvision
from torchvision.transforms import ToPILImage
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import KandinskyLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo
from ...models.transformers import Kandinsky5Transformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import KandinskyPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel
>>> from diffusers.utils import export_to_video
>>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
>>> pipe = pipe.to("cuda")
>>> prompt = "A cat and a dog baking a cake together in a kitchen."
>>> negative_prompt = "Bright tones, overexposed, static, blurred details"
>>> output = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=512,
... width=768,
... num_frames=25,
... num_inference_steps=50,
... guidance_scale=5.0,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=6)
```
"""
class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using Kandinsky 5.0.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
transformer ([`Kandinsky5Transformer3DModel`]):
Conditional Transformer to denoise the encoded video latents.
vae ([`AutoencoderKLHunyuanVideo`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
Frozen text-encoder (Qwen2.5-VL).
tokenizer ([`Qwen2VLProcessor`]):
Processor (tokenizer) for Qwen2.5-VL.
text_encoder_2 ([`CLIPTextModel`]):
Frozen CLIP text encoder.
tokenizer_2 ([`CLIPTokenizer`]):
Tokenizer for CLIP.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
transformer: Kandinsky5Transformer3DModel,
vae: AutoencoderKLHunyuanVideo,
text_encoder: Qwen2_5_VLForConditionalGeneration,
tokenizer: Qwen2VLProcessor,
text_encoder_2: CLIPTextModel,
tokenizer_2: CLIPTokenizer,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
)
self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
def _encode_prompt_qwen(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 256,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
# Kandinsky specific prompt template
prompt_template = "\n".join([
"<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
"Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
"Describe the location of the video, main characters or objects and their action.",
"Describe the dynamism of the video and presented actions.",
"Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
"Describe the visual effects, postprocessing and transitions if they are presented in the video.",
"Pay attention to the order of key actions shown in the scene.<|im_end|>",
"<|im_start|>user\n{}<|im_end|>",
])
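# `crop_start` is assumed to equal the token length of the chat template above: the first
# `crop_start` tokens (the system/user scaffolding) are cropped from the encoder output so that
# only the user prompt's hidden states remain. The template text is therefore kept verbatim,
# including its original spelling, so its tokenized length stays in sync with this constant.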
crop_start = 129
full_texts = [prompt_template.format(p) for p in prompt]
inputs = self.tokenizer(
text=full_texts,
images=None,
videos=None,
max_length=max_sequence_length + crop_start,
truncation=True,
return_tensors="pt",
padding=True,
).to(device)
with torch.no_grad():
embeds = self.text_encoder(
input_ids=inputs["input_ids"],
return_dict=True,
output_hidden_states=True,
)["hidden_states"][-1][:, crop_start:]
attention_mask = inputs["attention_mask"][:, crop_start:]
embeds = embeds[attention_mask.bool()]
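# Cumulative sequence lengths per prompt (varlen-attention style); only the final entry is used
# downstream, as the total token count when building the text RoPE positions.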
cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)
# duplicate for each generation per prompt
batch_size = len(prompt)
seq_len = embeds.shape[0] // batch_size
embeds = embeds.reshape(batch_size, seq_len, -1)
embeds = embeds.repeat(1, num_videos_per_prompt, 1)
embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return embeds, cu_seqlens
def _encode_prompt_clip(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_videos_per_prompt: int = 1,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
inputs = self.tokenizer_2(
prompt,
max_length=77,
truncation=True,
add_special_tokens=True,
padding="max_length",
return_tensors="pt",
).to(device)
with torch.no_grad():
pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
# duplicate for each generation per prompt
batch_size = len(prompt)
pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt)
pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1)
return pooled_embed
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
# Encode with Qwen2.5-VL
prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(
prompt, device, num_videos_per_prompt
)
pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt)
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(
negative_prompt, device, num_videos_per_prompt
)
negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt)
else:
negative_prompt_embeds = None
negative_pooled_embed = None
negative_cu_seqlens = None
text_embeds = {
"text_embeds": prompt_embeds,
"pooled_embed": pooled_embed,
}
negative_text_embeds = {
"text_embeds": negative_prompt_embeds,
"pooled_embed": negative_pooled_embed,
} if do_classifier_free_guidance else None
return text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 16,
height: int = 480,
width: int = 832,
num_frames: int = 81,
visual_cond: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_latent_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
num_channels_latents,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if visual_cond:
# For visual conditioning, concatenate with zeros and mask
visual_cond = torch.zeros_like(latents)
visual_cond_mask = torch.zeros(
[batch_size, num_latent_frames, int(height) // self.vae_scale_factor_spatial, int(width) // self.vae_scale_factor_spatial, 1],
dtype=latents.dtype,
device=latents.device
)
latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1)
return latents
def get_velocity(
self,
latents: torch.Tensor,
timestep: torch.Tensor,
text_embeds: Dict[str, torch.Tensor],
negative_text_embeds: Optional[Dict[str, torch.Tensor]],
visual_rope_pos: List[torch.Tensor],
text_rope_pos: torch.Tensor,
negative_text_rope_pos: torch.Tensor,
guidance_scale: float,
sparse_params: Optional[Dict] = None,
):
pred_velocity = self.transformer(
latents,
text_embeds["text_embeds"],
text_embeds["pooled_embed"],
timestep * 1000, # Scale to match training
visual_rope_pos,
text_rope_pos,
scale_factor=(1, 2, 2), # From Kandinsky config
sparse_params=sparse_params,
return_dict=False
)[0]
if guidance_scale > 1.0 and negative_text_embeds is not None:
uncond_pred_velocity = self.transformer(
latents,
negative_text_embeds["text_embeds"],
negative_text_embeds["pooled_embed"],
timestep * 1000,
visual_rope_pos,
negative_text_rope_pos,
scale_factor=(1, 2, 2),
sparse_params=sparse_params,
return_dict=False
)[0]
pred_velocity = uncond_pred_velocity + guidance_scale * (
pred_velocity - uncond_pred_velocity
)
return pred_velocity
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 768,
num_frames: int = 25,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
scheduler_scale: float = 10.0,
num_videos_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to avoid during video generation.
height (`int`, defaults to `512`):
The height in pixels of the generated video.
width (`int`, defaults to `768`):
The width in pixels of the generated video.
num_frames (`int`, defaults to `25`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps.
guidance_scale (`float`, defaults to `5.0`):
Guidance scale as defined in classifier-free guidance.
scheduler_scale (`float`, defaults to `10.0`):
Scale factor for the custom flow matching scheduler.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator`, *optional*):
A torch generator to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`KandinskyPipelineOutput`].
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step.
Examples:
Returns:
[`~KandinskyPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned
where the first element is a list with the generated video frames.
"""
# 1. Check inputs
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
# 2. Define call parameters
if isinstance(prompt, str):
batch_size = 1
else:
batch_size = len(prompt)
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
if num_frames % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
)
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
# 3. Encode input prompt
text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
device=device,
)
# 4. Prepare timesteps (Kandinsky uses custom flow matching)
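# Shifted flow-matching grid: t' = s * t / (1 + (s - 1) * t) with s = scheduler_scale keeps the
# shifted timesteps closely spaced near t = 1 (high noise) and widely spaced near t = 0.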
timesteps = torch.linspace(1, 0, num_inference_steps + 1, device=device)
timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps)
# 5. Prepare latent variables
num_channels_latents = 16
latents = self.prepare_latents(
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
num_frames=num_frames,
visual_cond=self.transformer.visual_cond,
dtype=self.transformer.dtype,
device=device,
generator=generator,
latents=latents,
)
# 6. Prepare rope positions
visual_rope_pos = [
torch.arange(num_frames // 4 + 1, device=device),
torch.arange(height // 8 // 2, device=device), # patch size 2
torch.arange(width // 8 // 2, device=device),
]
text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device)
negative_text_rope_pos = (
torch.arange(negative_cu_seqlens[-1].item(), device=device)
if negative_cu_seqlens is not None
else None
)
# 7. Prepare sparse attention params if needed
sparse_params = None # Can be extended based on Kandinsky attention config
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, (timestep, timestep_diff) in enumerate(zip(timesteps[:-1], torch.diff(timesteps))):
# Expand timestep to match batch size
time = timestep.unsqueeze(0)
pred_velocity = self.get_velocity(
latents,
time,
text_embeds,
negative_text_embeds,
visual_rope_pos,
text_rope_pos,
negative_text_rope_pos,
guidance_scale,
sparse_params,
)
# Update latents using flow matching
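# Euler step on the 16 noise channels only; with visual conditioning enabled, the remaining
# channels carry the conditioning latents and mask and must stay untouched.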
latents[:, :, :, :, :16] = latents[:, :, :, :, :16] + timestep_diff * pred_velocity
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 1 == 0):
progress_bar.update()
latents = latents[:, :, :, :, :16]
# 9. Decode latents to video
if output_type != "latent":
latents = latents.to(self.vae.dtype)
# Reshape and normalize latents
video = latents.reshape(
batch_size,
num_videos_per_prompt,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
height // 8,
width // 8,
16,
)
video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width]
video = video.reshape(batch_size * num_videos_per_prompt, 16, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // 8, width // 8)
# Normalize and decode
video = video / self.vae.config.scaling_factor
video = self.vae.decode(video).sample
video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
# Convert to output format
if output_type == "pil":
if num_frames == 1:
# Single image
video = [ToPILImage()(frame.squeeze(1)) for frame in video]
else:
# Video frames
video = [video[i] for i in range(video.shape[0])]
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return KandinskyPipelineOutput(frames=video)
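
For reference, the custom schedule built in step 4 of `__call__` can be inspected on its own. A standalone sketch with arbitrary settings, not pipeline code:

```python
import torch

# Rebuild the shifted flow-matching grid from step 4 with illustrative settings.
num_inference_steps, scheduler_scale = 8, 10.0
t = torch.linspace(1, 0, num_inference_steps + 1)
shifted = scheduler_scale * t / (1 + (scheduler_scale - 1) * t)
print(shifted)
# Values cluster near t = 1 (high noise) and spread out toward t = 0, so torch.diff(shifted)
# yields small updates early in the denoising loop and larger ones near the end.
```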


@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class KandinskyPipelineOutput(BaseOutput):
r"""
Output class for Kandinsky 5.0 video pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor