From d53f848720a03423bb9998e75a30b4c3cd04e96d Mon Sep 17 00:00:00 2001 From: leffff Date: Sat, 4 Oct 2025 10:10:23 +0000 Subject: [PATCH] add transformer pipeline first version --- src/diffusers/__init__.py | 4 + src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 288 +++++++- src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_kandinsky.py | 630 ++++++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../pipelines/kandinsky5/__init__.py | 48 ++ .../kandinsky5/pipeline_kandinsky.py | 545 +++++++++++++++ .../pipelines/kandinsky5/pipeline_output.py | 20 + 10 files changed, 1541 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/models/transformers/transformer_kandinsky.py create mode 100644 src/diffusers/pipelines/kandinsky5/__init__.py create mode 100644 src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py create mode 100644 src/diffusers/pipelines/kandinsky5/pipeline_output.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8867250ded..19670053a3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -260,6 +260,7 @@ else: "VQModel", "WanTransformer3DModel", "WanVACETransformer3DModel", + "Kandinsky5Transformer3DModel", "attention_backend", ] ) @@ -618,6 +619,7 @@ else: "WanPipeline", "WanVACEPipeline", "WanVideoToVideoPipeline", + "Kandinsky5T2VPipeline", "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", @@ -947,6 +949,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: VQModel, WanTransformer3DModel, WanVACETransformer3DModel, + Kandinsky5Transformer3DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1275,6 +1278,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline, + Kandinsky5T2VPipeline, WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 7425486538..6a48ac1b0d 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -77,6 +77,7 @@ if is_torch_available(): "SanaLoraLoaderMixin", "Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", + "KandinskyLoraLoaderMixin", "HiDreamImageLoraLoaderMixin", "SkyReelsV2LoraLoaderMixin", "QwenImageLoraLoaderMixin", @@ -126,6 +127,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, + KandinskyLoraLoaderMixin ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e25a29e1c0..ea1b92c68b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3638,6 +3638,292 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): """ super().unfuse_lora(components=components, **kwargs) + +class KandinskyLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Kandinsky5Transformer3DModel`], + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. 
+ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + - A string, the *model id* of a pretrained model hosted on the Hub. + - A path to a *directory* containing the model weights. + - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + use_safetensors (`bool`, *optional*): + Whether to use safetensors for loading. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata. + """ + # Load the main state dict first which has the LoRA layers + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. + hotswap (`bool`, *optional*): + Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + kwargs (`dict`, *optional*): + See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + # Load LoRA into transformer + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + Load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. + transformer (`Kandinsky5Transformer3DModel`): + The transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights. + hotswap (`bool`, *optional*): + See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. + """ + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. 
+ logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + ): + r""" + Save the LoRA parameters corresponding to the transformer and text encoders. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process. + save_function (`Callable`): + The function to use to save the state dictionary. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer. + """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError( + "You must pass at least one of `transformer_lora_layers`" + ) + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. + + Example: + ```py + from diffusers import Kandinsky5T2VPipeline + + pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + pipeline.load_lora_weights("path/to/lora.safetensors") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of [`pipe.fuse_lora()`]. + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. 
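+
+        Example (a minimal sketch mirroring the `fuse_lora` example above; the LoRA path is illustrative):
+        ```py
+        from diffusers import Kandinsky5T2VPipeline
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        pipeline.load_lora_weights("path/to/lora.safetensors")
+        pipeline.fuse_lora(lora_scale=0.7)
+        # ... run inference with the fused weights, then restore the original parameters:
+        pipeline.unfuse_lora()
+        ```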
+ """ + super().unfuse_lora(components=components, **kwargs) + class WanLoraLoaderMixin(LoraBaseMixin): r""" @@ -4802,4 +5088,4 @@ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." deprecate("LoraLoaderMixin", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 457f70448a..89ca9d3977 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -101,6 +101,7 @@ if is_torch_available(): _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] + _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -200,6 +201,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: TransformerTemporalModel, WanTransformer3DModel, WanVACETransformer3DModel, + Kandinsky5Transformer3DModel, ) from .unets import ( I2VGenXLUNet, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6..4b9911f9cb 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -37,3 +37,4 @@ if is_torch_available(): from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel + from .transformer_kandinsky import Kandinsky5Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py new file mode 100644 index 0000000000..a057cc13cc --- /dev/null +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -0,0 +1,630 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Any, Dict, Optional, Tuple, Union, List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm + +logger = logging.get_logger(__name__) + + +# @torch.compile() +@torch.autocast(device_type="cuda", dtype=torch.float32) +def apply_scale_shift_norm(norm, x, scale, shift): + return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16) + +# @torch.compile() +@torch.autocast(device_type="cuda", dtype=torch.float32) +def apply_gate_sum(x, out, gate): + return (x + gate * out).to(torch.bfloat16) + +# @torch.compile() +@torch.autocast(device_type="cuda", enabled=False) +def apply_rotary(x, rope): + x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) + x_out = (rope * x_).sum(dim=-1) + return x_out.reshape(*x.shape).to(torch.bfloat16) + + +@torch.autocast(device_type="cuda", enabled=False) +def get_freqs(dim, max_period=10000.0): + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=dim, dtype=torch.float32) + / dim + ) + return freqs + + +class TimeEmbeddings(nn.Module): + def __init__(self, model_dim, time_dim, max_period=10000.0): + super().__init__() + assert model_dim % 2 == 0 + self.model_dim = model_dim + self.max_period = max_period + self.register_buffer( + "freqs", get_freqs(model_dim // 2, max_period), persistent=False + ) + self.in_layer = nn.Linear(model_dim, time_dim, bias=True) + self.activation = nn.SiLU() + self.out_layer = nn.Linear(time_dim, time_dim, bias=True) + + def forward(self, time): + args = torch.outer(time, self.freqs.to(device=time.device)) + time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) + return time_embed + + +class TextEmbeddings(nn.Module): + def __init__(self, text_dim, model_dim): + super().__init__() + self.in_layer = nn.Linear(text_dim, model_dim, bias=True) + self.norm = nn.LayerNorm(model_dim, elementwise_affine=True) + + def forward(self, text_embed): + text_embed = self.in_layer(text_embed) + return self.norm(text_embed).type_as(text_embed) + + +class VisualEmbeddings(nn.Module): + def __init__(self, visual_dim, model_dim, patch_size): + super().__init__() + self.patch_size = patch_size + self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim) + + def forward(self, x): + batch_size, duration, height, width, dim = x.shape + x = ( + x.view( + batch_size, + duration // self.patch_size[0], + self.patch_size[0], + height // self.patch_size[1], + self.patch_size[1], + width // self.patch_size[2], + self.patch_size[2], + dim, + ) + .permute(0, 1, 3, 5, 2, 4, 6, 7) + .flatten(4, 7) + ) + return self.in_layer(x) + + +class RoPE1D(nn.Module): + """ + 1D Rotary Positional Embeddings for text sequences. 
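+    The positions are encoded as stacked 2x2 rotation matrices that `apply_rotary` multiplies against
+    consecutive query/key channel pairs, rather than as interleaved cos/sin tensors.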
+ + Args: + dim: Dimension of the rotary embeddings + max_pos: Maximum sequence length + max_period: Maximum period for sinusoidal embeddings + """ + + def __init__(self, dim, max_pos=1024, max_period=10000.0): + super().__init__() + self.max_period = max_period + self.dim = dim + self.max_pos = max_pos + freq = get_freqs(dim // 2, max_period) + pos = torch.arange(max_pos, dtype=freq.dtype) + self.register_buffer("args", torch.outer(pos, freq), persistent=False) + + def forward(self, pos): + """ + Args: + pos: Position indices of shape [seq_len] or [batch_size, seq_len] + + Returns: + Rotary embeddings of shape [seq_len, 1, 2, 2] + """ + args = self.args[pos] + cosine = torch.cos(args) + sine = torch.sin(args) + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) + rope = rope.view(*rope.shape[:-1], 2, 2) + return rope.unsqueeze(-4) + + +class RoPE3D(nn.Module): + def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): + super().__init__() + self.axes_dims = axes_dims + self.max_pos = max_pos + self.max_period = max_period + + for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)): + freq = get_freqs(axes_dim // 2, max_period) + pos = torch.arange(ax_max_pos, dtype=freq.dtype) + self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False) + + @torch.autocast(device_type="cuda", enabled=False) + def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): + batch_size, duration, height, width = shape + args_t = self.args_0[pos[0]] / scale_factor[0] + args_h = self.args_1[pos[1]] / scale_factor[1] + args_w = self.args_2[pos[2]] / scale_factor[2] + + # Replicate the original logic with batch dimension + args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1) + args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1) + args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1) + + # Concatenate along the last dimension + args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) # [B, D, H, W, F] + + cosine = torch.cos(args) + sine = torch.sin(args) + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) # [B, D, H, W, F, 4] + rope = rope.view(*rope.shape[:-1], 2, 2) # [B, D, H, W, F, 2, 2] + return rope.unsqueeze(-4) # [B, D, H, 1, W, F, 2, 2] + + +class Modulation(nn.Module): + def __init__(self, time_dim, model_dim, num_params): + super().__init__() + self.activation = nn.SiLU() + self.out_layer = nn.Linear(time_dim, num_params * model_dim) + self.out_layer.weight.data.zero_() + self.out_layer.bias.data.zero_() + + def forward(self, x): + return self.out_layer(self.activation(x)) + + +class MultiheadSelfAttentionEnc(nn.Module): + def __init__(self, num_channels, head_dim): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + + self.to_query = nn.Linear(num_channels, num_channels, bias=True) + self.to_key = nn.Linear(num_channels, num_channels, bias=True) + self.to_value = nn.Linear(num_channels, num_channels, bias=True) + self.query_norm = nn.RMSNorm(head_dim) + self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + def forward(self, x, rope): + query = self.to_query(x) + key = self.to_key(x) + value = self.to_value(x) + + shape = query.shape[:-1] + query = query.reshape(*shape, self.num_heads, -1) + key = key.reshape(*shape, self.num_heads, -1) + value = value.reshape(*shape, self.num_heads, -1) + + 
query = self.query_norm(query.float()).type_as(query) + key = self.key_norm(key.float()).type_as(key) + + query = apply_rotary(query, rope).type_as(query) + key = apply_rotary(key, rope).type_as(key) + + # Use torch's scaled_dot_product_attention + out = F.scaled_dot_product_attention( + query, + key, + value, + ).flatten(-2, -1) + + out = self.out_layer(out) + return out + + +class MultiheadSelfAttentionDec(nn.Module): + def __init__(self, num_channels, head_dim): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + + self.to_query = nn.Linear(num_channels, num_channels, bias=True) + self.to_key = nn.Linear(num_channels, num_channels, bias=True) + self.to_value = nn.Linear(num_channels, num_channels, bias=True) + self.query_norm = nn.RMSNorm(head_dim) + self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + def forward(self, x, rope, sparse_params=None): + query = self.to_query(x) + key = self.to_key(x) + value = self.to_value(x) + + shape = query.shape[:-1] + query = query.reshape(*shape, self.num_heads, -1) + key = key.reshape(*shape, self.num_heads, -1) + value = value.reshape(*shape, self.num_heads, -1) + + query = self.query_norm(query.float()).type_as(query) + key = self.key_norm(key.float()).type_as(key) + + query = apply_rotary(query, rope).type_as(query) + key = apply_rotary(key, rope).type_as(key) + + # Use standard attention (can be extended with sparse attention) + out = F.scaled_dot_product_attention( + query, + key, + value, + ).flatten(-2, -1) + + out = self.out_layer(out) + return out + + +class MultiheadCrossAttention(nn.Module): + def __init__(self, num_channels, head_dim): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + + self.to_query = nn.Linear(num_channels, num_channels, bias=True) + self.to_key = nn.Linear(num_channels, num_channels, bias=True) + self.to_value = nn.Linear(num_channels, num_channels, bias=True) + self.query_norm = nn.RMSNorm(head_dim) + self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + def forward(self, x, cond): + query = self.to_query(x) + key = self.to_key(cond) + value = self.to_value(cond) + + shape, cond_shape = query.shape[:-1], key.shape[:-1] + query = query.reshape(*shape, self.num_heads, -1) + key = key.reshape(*cond_shape, self.num_heads, -1) + value = value.reshape(*cond_shape, self.num_heads, -1) + + query = self.query_norm(query.float()).type_as(query) + key = self.key_norm(key.float()).type_as(key) + + out = F.scaled_dot_product_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + ).permute(0, 2, 1, 3).flatten(-2, -1) + + out = self.out_layer(out) + return out + + +class FeedForward(nn.Module): + def __init__(self, dim, ff_dim): + super().__init__() + self.in_layer = nn.Linear(dim, ff_dim, bias=False) + self.activation = nn.GELU() + self.out_layer = nn.Linear(ff_dim, dim, bias=False) + + def forward(self, x): + return self.out_layer(self.activation(self.in_layer(x))) + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim): + super().__init__() + self.text_modulation = Modulation(time_dim, model_dim, 6) + + self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.self_attention = MultiheadSelfAttentionEnc(model_dim, head_dim) + + self.feed_forward_norm = nn.LayerNorm(model_dim, 
elementwise_affine=False) + self.feed_forward = FeedForward(model_dim, ff_dim) + + def forward(self, x, time_embed, rope): + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) + + out = self.self_attention_norm(x) + out = out * (scale + 1.0) + shift + out = self.self_attention(out, rope) + x = x + gate * out + + shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) + out = self.feed_forward_norm(x) + out = out * (scale + 1.0) + shift + out = self.feed_forward(out) + x = x + gate * out + return x + + +class TransformerDecoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim): + super().__init__() + self.visual_modulation = Modulation(time_dim, model_dim, 9) + + self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.self_attention = MultiheadSelfAttentionDec(model_dim, head_dim) + + self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.cross_attention = MultiheadCrossAttention(model_dim, head_dim) + + self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.feed_forward = FeedForward(model_dim, ff_dim) + + def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): + self_attn_params, cross_attn_params, ff_params = torch.chunk( + self.visual_modulation(time_embed), 3, dim=-1 + ) + shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) + + visual_out = self.self_attention_norm(visual_embed) + visual_out = visual_out * (scale + 1.0) + shift + visual_out = self.self_attention(visual_out, rope, sparse_params) + visual_embed = visual_embed + gate * visual_out + + shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) + visual_out = self.cross_attention_norm(visual_embed) + visual_out = visual_out * (scale + 1.0) + shift + visual_out = self.cross_attention(visual_out, text_embed) + visual_embed = visual_embed + gate * visual_out + + shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) + visual_out = self.feed_forward_norm(visual_embed) + visual_out = visual_out * (scale + 1.0) + shift + visual_out = self.feed_forward(visual_out) + visual_embed = visual_embed + gate * visual_out + return visual_embed + + +class OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size): + super().__init__() + self.patch_size = patch_size + self.modulation = Modulation(time_dim, model_dim, 2) + self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.out_layer = nn.Linear( + model_dim, math.prod(patch_size) * visual_dim, bias=True + ) + + def forward(self, visual_embed, text_embed, time_embed): + # Handle the new batch dimension: [batch, duration, height, width, model_dim] + batch_size, duration, height, width, _ = visual_embed.shape + + shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1) + + # Apply modulation with proper broadcasting for the new shape + visual_embed = apply_scale_shift_norm( + self.norm, + visual_embed, + scale[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1] + shift[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1] + ).type_as(visual_embed) + + x = self.out_layer(visual_embed) + + # Now x has shape [batch, duration, height, width, patch_prod * visual_dim] + x = ( + x.view( + batch_size, + duration, + height, + width, + -1, + self.patch_size[0], + self.patch_size[1], + self.patch_size[2], + ) + .permute(0, 5, 1, 6, 2, 7, 3, 4) # [batch, patch_t, duration, 
patch_h, height, patch_w, width, features] + .flatten(1, 2) # [batch, patch_t * duration, height, patch_w, width, features] + .flatten(2, 3) # [batch, patch_t * duration, patch_h * height, width, features] + .flatten(3, 4) # [batch, patch_t * duration, patch_h * height, patch_w * width] + ) + return x + + +@maybe_allow_in_graph +class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): + r""" + A 3D Transformer model for video generation used in Kandinsky 5.0. + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods implemented for all models (such as downloading or saving). + + Args: + in_visual_dim (`int`, defaults to 16): + Number of channels in the input visual latent. + out_visual_dim (`int`, defaults to 16): + Number of channels in the output visual latent. + time_dim (`int`, defaults to 512): + Dimension of the time embeddings. + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + Patch size for the visual embeddings (temporal, height, width). + model_dim (`int`, defaults to 1792): + Hidden dimension of the transformer model. + ff_dim (`int`, defaults to 7168): + Intermediate dimension of the feed-forward networks. + num_text_blocks (`int`, defaults to 2): + Number of transformer blocks in the text encoder. + num_visual_blocks (`int`, defaults to 32): + Number of transformer blocks in the visual decoder. + axes_dims (`Tuple[int]`, defaults to `(16, 24, 24)`): + Dimensions for the rotary positional embeddings (temporal, height, width). + visual_cond (`bool`, defaults to `True`): + Whether to use visual conditioning (for image/video conditioning). + in_text_dim (`int`, defaults to 3584): + Dimension of the text embeddings from Qwen2.5-VL. + in_text_dim2 (`int`, defaults to 768): + Dimension of the pooled text embeddings from CLIP. + """ + + @register_to_config + def __init__( + self, + in_visual_dim: int = 16, + out_visual_dim: int = 16, + time_dim: int = 512, + patch_size: Tuple[int, int, int] = (1, 2, 2), + model_dim: int = 1792, + ff_dim: int = 7168, + num_text_blocks: int = 2, + num_visual_blocks: int = 32, + axes_dims: Tuple[int, int, int] = (16, 24, 24), + visual_cond: bool = True, + in_text_dim: int = 3584, + in_text_dim2: int = 768, + ): + super().__init__() + + self.in_visual_dim = in_visual_dim + self.model_dim = model_dim + self.patch_size = patch_size + self.visual_cond = visual_cond + + # Calculate head dimension for attention + head_dim = sum(axes_dims) + + # Determine visual embedding dimension based on conditioning + visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim + + # 1. Embedding layers + self.time_embeddings = TimeEmbeddings(model_dim, time_dim) + self.text_embeddings = TextEmbeddings(in_text_dim, model_dim) + self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim) + self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size) + + # 2. Rotary positional embeddings + self.text_rope_embeddings = RoPE1D(head_dim) + self.visual_rope_embeddings = RoPE3D(axes_dims) + + # 3. Transformer blocks + self.text_transformer_blocks = nn.ModuleList([ + TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) + for _ in range(num_text_blocks) + ]) + + self.visual_transformer_blocks = nn.ModuleList([ + TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) + for _ in range(num_visual_blocks) + ]) + + # 4. 
Output layer + self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + pooled_text_embed: torch.Tensor, + timestep: torch.Tensor, + visual_rope_pos: List[torch.Tensor], + text_rope_pos: torch.Tensor, + scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), + sparse_params: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + Forward pass of the Kandinsky 5.0 3D Transformer. + + Args: + hidden_states (`torch.Tensor`): + Input visual latent tensor of shape `(batch_size, num_frames, height, width, channels)`. + encoder_hidden_states (`torch.Tensor`): + Text embeddings from Qwen2.5-VL of shape `(batch_size, sequence_length, text_dim)`. + pooled_text_embed (`torch.Tensor`): + Pooled text embeddings from CLIP of shape `(batch_size, pooled_text_dim)`. + timestep (`torch.Tensor`): + Timestep tensor of shape `(batch_size,)` or `(batch_size * num_frames,)`. + visual_rope_pos (`List[torch.Tensor]`): + List of tensors for visual rotary positional embeddings [temporal, height, width]. + text_rope_pos (`torch.Tensor`): + Tensor for text rotary positional embeddings. + scale_factor (`Tuple[float, float, float]`, defaults to `(1.0, 1.0, 1.0)`): + Scale factors for rotary positional embeddings. + sparse_params (`Dict[str, Any]`, *optional*): + Parameters for sparse attention. + return_dict (`bool`, defaults to `True`): + Whether to return a dictionary or a tensor. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned, + otherwise a `tuple` where the first element is the sample tensor. + """ + batch_size, num_frames, height, width, channels = hidden_states.shape + + # 1. Process text embeddings + text_embed = self.text_embeddings(encoder_hidden_states) + time_embed = self.time_embeddings(timestep) + + # Add pooled text embedding to time embedding + pooled_embed = self.pooled_text_embeddings(pooled_text_embed) + time_embed = time_embed + pooled_embed + + # visual_embed shape: [batch_size, seq_len, model_dim] + visual_embed = self.visual_embeddings(hidden_states) + + # 3. Text rotary embeddings + text_rope = self.text_rope_embeddings(text_rope_pos) + + # 4. Text transformer blocks + for text_block in self.text_transformer_blocks: + if self.gradient_checkpointing and self.training: + text_embed = torch.utils.checkpoint.checkpoint( + text_block, text_embed, time_embed, text_rope, use_reentrant=False + ) + else: + text_embed = text_block(text_embed, time_embed, text_rope) + + # 5. Prepare visual rope + visual_shape = visual_embed.shape[:-1] + visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + + visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) + visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) + + # 6. Visual transformer blocks + for visual_block in self.visual_transformer_blocks: + if self.gradient_checkpointing and self.training: + visual_embed = torch.utils.checkpoint.checkpoint( + visual_block, + visual_embed, + text_embed, + time_embed, + visual_rope, + # visual_rope_flat, + sparse_params, + use_reentrant=False, + ) + else: + visual_embed = visual_block( + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) + + # 7. 
Output projection + visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1) + output = self.out_layer(visual_embed, text_embed, time_embed) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 190c7871d2..201d92afb0 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -382,6 +382,7 @@ else: "WuerstchenPriorPipeline", ] _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"] + _import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -787,6 +788,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ) from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline + from .kandinsky5 import Kandinsky5T2VPipeline from .wuerstchen import ( WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, diff --git a/src/diffusers/pipelines/kandinsky5/__init__.py b/src/diffusers/pipelines/kandinsky5/__init__.py new file mode 100644 index 0000000000..af8e124217 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky5/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_kandinsky import Kandinsky5T2VPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py new file mode 100644 index 0000000000..02eae13633 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -0,0 +1,545 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer +import torchvision +from torchvision.transforms import ToPILImage + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel + >>> from diffusers.utils import export_to_video + + >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=512, + ... width=768, + ... num_frames=25, + ... num_inference_steps=50, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=6) + ``` +""" + + +class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using Kandinsky 5.0. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`Kandinsky5Transformer3DModel`]): + Conditional Transformer to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder (Qwen2.5-VL). + tokenizer ([`AutoProcessor`]): + Tokenizer for Qwen2.5-VL. + text_encoder_2 ([`CLIPTextModel`]): + Frozen CLIP text encoder. + tokenizer_2 ([`CLIPTokenizer`]): + Tokenizer for CLIP. 
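+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.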
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + transformer: Kandinsky5Transformer3DModel, + vae: AutoencoderKLHunyuanVideo, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2VLProcessor, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio + self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio + + def _encode_prompt_qwen( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 256, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Kandinsky specific prompt template + prompt_template = "\n".join([ + "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", + "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", + "Describe the location of the video, main characters or objects and their action.", + "Describe the dynamism of the video and presented actions.", + "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", + "Describe the visual effects, postprocessing and transitions if they are presented in the video.", + "Pay attention to the order of key actions shown in the scene.<|im_end|>", + "<|im_start|>user\n{}<|im_end|>", + ]) + crop_start = 129 + + full_texts = [prompt_template.format(p) for p in prompt] + + inputs = self.tokenizer( + text=full_texts, + images=None, + videos=None, + max_length=max_sequence_length + crop_start, + truncation=True, + return_tensors="pt", + padding=True, + ).to(device) + + with torch.no_grad(): + embeds = self.text_encoder( + input_ids=inputs["input_ids"], + return_dict=True, + output_hidden_states=True, + )["hidden_states"][-1][:, crop_start:] + + attention_mask = inputs["attention_mask"][:, crop_start:] + embeds = embeds[attention_mask.bool()] + cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) + cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) + + # duplicate for each generation per prompt + batch_size = len(prompt) + seq_len = embeds.shape[0] // batch_size + embeds = embeds.reshape(batch_size, seq_len, -1) + embeds = embeds.repeat(1, num_videos_per_prompt, 1) + embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return embeds, cu_seqlens + + def _encode_prompt_clip( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + + inputs = self.tokenizer_2( + prompt, + max_length=77, + truncation=True, + add_special_tokens=True, + padding="max_length", + return_tensors="pt", + ).to(device) + + with torch.no_grad(): + pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] + + # duplicate for each generation per prompt + batch_size = len(prompt) + pooled_embed = 
pooled_embed.repeat(1, num_videos_per_prompt, 1) + pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1) + + return pooled_embed + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + # Encode with Qwen2.5-VL + prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt, device, num_videos_per_prompt + ) + pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt) + + if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen( + negative_prompt, device, num_videos_per_prompt + ) + negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt) + else: + negative_prompt_embeds = None + negative_pooled_embed = None + negative_cu_seqlens = None + + text_embeds = { + "text_embeds": prompt_embeds, + "pooled_embed": pooled_embed, + } + negative_text_embeds = { + "text_embeds": negative_prompt_embeds, + "pooled_embed": negative_pooled_embed, + } if do_classifier_free_guidance else None + + return text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + visual_cond: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if visual_cond: + # For visual conditioning, concatenate with zeros and mask + visual_cond = torch.zeros_like(latents) + visual_cond_mask = torch.zeros( + [batch_size, num_latent_frames, int(height) // self.vae_scale_factor_spatial, int(width) // self.vae_scale_factor_spatial, 1], + dtype=latents.dtype, + device=latents.device + ) + latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1) + + return latents + + def get_velocity( + self, + latents: torch.Tensor, + timestep: torch.Tensor, + text_embeds: Dict[str, torch.Tensor], + negative_text_embeds: Optional[Dict[str, torch.Tensor]], + visual_rope_pos: List[torch.Tensor], + text_rope_pos: torch.Tensor, + negative_text_rope_pos: torch.Tensor, + guidance_scale: float, + sparse_params: Optional[Dict] = None, + ): + # print(latents.shape, text_embeds["text_embeds"].shape, text_embeds["pooled_embed"].shape, timestep.shape, [el.shape for el in visual_rope_pos], text_rope_pos, sparse_params) + + pred_velocity = self.transformer( + latents, + text_embeds["text_embeds"], + text_embeds["pooled_embed"], + timestep * 1000, # Scale to match training + visual_rope_pos, + text_rope_pos, + scale_factor=(1, 2, 2), # From Kandinsky config + sparse_params=sparse_params, + return_dict=False + )[0] + + if guidance_scale > 1.0 and negative_text_embeds is not None: + uncond_pred_velocity = self.transformer( + latents, + negative_text_embeds["text_embeds"], + negative_text_embeds["pooled_embed"], + timestep * 1000, + visual_rope_pos, + negative_text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=sparse_params, + return_dict=False + )[0] + + pred_velocity = uncond_pred_velocity + guidance_scale * ( + pred_velocity - uncond_pred_velocity + ) + + return pred_velocity + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 25, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + scheduler_scale: float = 10.0, + num_videos_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during video generation. + height (`int`, defaults to `512`): + The height in pixels of the generated video. + width (`int`, defaults to `768`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `25`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in classifier-free guidance. + scheduler_scale (`float`, defaults to `10.0`): + Scale factor for the custom flow matching scheduler. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator`, *optional*): + A torch generator to make generation deterministic. 
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`KandinskyPipelineOutput`].
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step.
+            callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to `["latents"]`):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                must be in the `_callback_tensor_inputs` attribute of the pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~KandinskyPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned
+                where the first element is a list with the generated video frames.
+        """
+
+        # 1. Check inputs
+        if height % 16 != 0 or width % 16 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+        # 2. Define call parameters
+        if isinstance(prompt, str):
+            batch_size = 1
+        else:
+            batch_size = len(prompt)
+
+        device = self._execution_device
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        if num_frames % self.vae_scale_factor_temporal != 1:
+            logger.warning(
+                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+            )
+            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+            num_frames = max(num_frames, 1)
+
+        # 3. Encode input prompt
+        text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            num_videos_per_prompt=num_videos_per_prompt,
+            device=device,
+        )
+
+        # 4. Prepare timesteps (Kandinsky uses custom flow matching)
+        timesteps = torch.linspace(1, 0, num_inference_steps + 1, device=device)
+        timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps)
+
+        # 5. Prepare latent variables
+        num_channels_latents = 16
+        latents = self.prepare_latents(
+            batch_size=batch_size * num_videos_per_prompt,
+            num_channels_latents=num_channels_latents,
+            height=height,
+            width=width,
+            num_frames=num_frames,
+            visual_cond=self.transformer.visual_cond,
+            dtype=self.transformer.dtype,
+            device=device,
+            generator=generator,
+            latents=latents,
+        )
+
+        # 6. Prepare rope positions
+        visual_rope_pos = [
+            torch.arange(num_frames // 4 + 1, device=device),
+            torch.arange(height // 8 // 2, device=device),  # patch size 2
+            torch.arange(width // 8 // 2, device=device),
+        ]
+
+        text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device)
+
+        negative_text_rope_pos = (
+            torch.arange(negative_cu_seqlens[-1].item(), device=device)
+            if negative_cu_seqlens is not None
+            else None
+        )
+
+        # 7. Prepare sparse attention params if needed
+        sparse_params = None  # Can be extended based on Kandinsky attention config
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, (timestep, timestep_diff) in enumerate(zip(timesteps[:-1], torch.diff(timesteps))):
+                # Expand timestep to match batch size
+                time = timestep.unsqueeze(0)
+
+                pred_velocity = self.get_velocity(
+                    latents,
+                    time,
+                    text_embeds,
+                    negative_text_embeds,
+                    visual_rope_pos,
+                    text_rope_pos,
+                    negative_text_rope_pos,
+                    guidance_scale,
+                    sparse_params,
+                )
+
+                # Update latents using flow matching
+                latents[:, :, :, :, :16] = latents[:, :, :, :, :16] + timestep_diff * pred_velocity
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
+                    latents = callback_outputs.pop("latents", latents)
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 1 == 0):
+                    progress_bar.update()
+
+        latents = latents[:, :, :, :, :16]
+
+        # 9. Decode latents to video
+        if output_type != "latent":
+            latents = latents.to(self.vae.dtype)
+            # Reshape and normalize latents
+            video = latents.reshape(
+                batch_size,
+                num_videos_per_prompt,
+                (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+                height // 8,
+                width // 8,
+                16,
+            )
+            video = video.permute(0, 1, 5, 2, 3, 4)  # [batch, num_videos, channels, frames, height, width]
+            video = video.reshape(batch_size * num_videos_per_prompt, 16, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // 8, width // 8)
+
+            # Normalize and decode
+            video = video / self.vae.config.scaling_factor
+            video = self.vae.decode(video).sample
+            video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
+
+            # Convert to output format
+            if output_type == "pil":
+                if num_frames == 1:
+                    # Single image
+                    video = [ToPILImage()(frame.squeeze(1)) for frame in video]
+                else:
+                    # Video frames
+                    video = [video[i] for i in range(video.shape[0])]
+        else:
+            video = latents
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (video,)
+
+        return KandinskyPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_output.py b/src/diffusers/pipelines/kandinsky5/pipeline_output.py
new file mode 100644
index 0000000000..ed77d42a9a
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class KandinskyPipelineOutput(BaseOutput):
+    r"""
+    Output class for Kandinsky pipelines.
+
+    Args:
+        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+    """
+
+    frames: torch.Tensor