Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)

Commit: add transformer pipeline first version
@@ -260,6 +260,7 @@ else:
            "VQModel",
            "WanTransformer3DModel",
            "WanVACETransformer3DModel",
            "Kandinsky5Transformer3DModel",
            "attention_backend",
        ]
    )
@@ -618,6 +619,7 @@ else:
            "WanPipeline",
            "WanVACEPipeline",
            "WanVideoToVideoPipeline",
            "Kandinsky5T2VPipeline",
            "WuerstchenCombinedPipeline",
            "WuerstchenDecoderPipeline",
            "WuerstchenPriorPipeline",
@@ -947,6 +949,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        VQModel,
        WanTransformer3DModel,
        WanVACETransformer3DModel,
        Kandinsky5Transformer3DModel,
        attention_backend,
    )
    from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
@@ -1275,6 +1278,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        WanPipeline,
        WanVACEPipeline,
        WanVideoToVideoPipeline,
        Kandinsky5T2VPipeline,
        WuerstchenCombinedPipeline,
        WuerstchenDecoderPipeline,
        WuerstchenPriorPipeline,

@@ -77,6 +77,7 @@ if is_torch_available():
        "SanaLoraLoaderMixin",
        "Lumina2LoraLoaderMixin",
        "WanLoraLoaderMixin",
        "KandinskyLoraLoaderMixin",
        "HiDreamImageLoraLoaderMixin",
        "SkyReelsV2LoraLoaderMixin",
        "QwenImageLoraLoaderMixin",
@@ -126,6 +127,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        StableDiffusionLoraLoaderMixin,
        StableDiffusionXLLoraLoaderMixin,
        WanLoraLoaderMixin,
        KandinskyLoraLoaderMixin
    )
    from .single_file import FromSingleFileMixin
    from .textual_inversion import TextualInversionLoaderMixin

@@ -3638,6 +3638,292 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
        """
        super().unfuse_lora(components=components, **kwargs)


class KandinskyLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`Kandinsky5Transformer3DModel`].
    """

    _lora_loadable_modules = ["transformer"]
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:
                    - A string, the *model id* of a pretrained model hosted on the Hub.
                    - A path to a *directory* containing the model weights.
                    - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.
            use_safetensors (`bool`, *optional*):
                Whether to use safetensors for loading.
            return_lora_metadata (`bool`, *optional*, defaults to False):
                When enabled, additionally return the LoRA adapter metadata.
        """
        # Load the main state dict first which has the LoRA layers
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        out = (state_dict, metadata) if return_lora_metadata else state_dict
        return out

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model.
            hotswap (`bool`, *optional*):
                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        # Load LoRA into transformer
        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        Load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters.
            transformer (`Kandinsky5Transformer3DModel`):
                The transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights.
            hotswap (`bool`, *optional*):
                See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
            metadata (`dict`):
                Optional LoRA adapter metadata.
        """
        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata=None,
    ):
        r"""
        Save the LoRA parameters corresponding to the transformer and text encoders.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process.
            save_function (`Callable`):
                The function to use to save the state dictionary.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way.
            transformer_lora_adapter_metadata:
                LoRA adapter metadata associated with the transformer.
        """
        lora_layers = {}
        lora_metadata = {}

        if transformer_lora_layers:
            lora_layers[cls.transformer_name] = transformer_lora_layers
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        if not lora_layers:
            raise ValueError(
                "You must pass at least one of `transformer_lora_layers`"
            )

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        Args:
            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing.

        Example:

        ```py
        from diffusers import Kandinsky5T2VPipeline

        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
        pipeline.load_lora_weights("path/to/lora.safetensors")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        Reverses the effect of [`pipe.fuse_lora()`].

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
        """
        super().unfuse_lora(components=components, **kwargs)

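The mixin above only targets the `transformer` component. A minimal, hedged sketch of how it is typically exercised (the checkpoint path and adapter name below are placeholders, not files shipped with the commit; the PEFT backend must be installed):

```python
# Hedged usage sketch for KandinskyLoraLoaderMixin; "my_kandinsky_lora.safetensors"
# is a hypothetical LoRA checkpoint in the Kandinsky transformer key layout.
import torch
from diffusers import Kandinsky5T2VPipeline

pipe = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V", torch_dtype=torch.bfloat16
).to("cuda")

# Inspect the raw LoRA state dict (and its metadata) without touching the model.
state_dict, metadata = pipe.lora_state_dict(
    "my_kandinsky_lora.safetensors", return_lora_metadata=True
)
print(len(state_dict), "LoRA tensors")

# Load the adapter into the transformer, then optionally fuse it for inference.
pipe.load_lora_weights("my_kandinsky_lora.safetensors", adapter_name="style")
pipe.fuse_lora(lora_scale=0.8)
```
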
class WanLoraLoaderMixin(LoraBaseMixin):
    r"""
@@ -4802,4 +5088,4 @@ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    def __init__(self, *args, **kwargs):
        deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
        deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)

@@ -101,6 +101,7 @@ if is_torch_available():
    _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
    _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
    _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
    _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
    _import_structure["unets.unet_1d"] = ["UNet1DModel"]
    _import_structure["unets.unet_2d"] = ["UNet2DModel"]
    _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -200,6 +201,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        TransformerTemporalModel,
        WanTransformer3DModel,
        WanVACETransformer3DModel,
        Kandinsky5Transformer3DModel,
    )
    from .unets import (
        I2VGenXLUNet,

@@ -37,3 +37,4 @@ if is_torch_available():
    from .transformer_temporal import TransformerTemporalModel
    from .transformer_wan import WanTransformer3DModel
    from .transformer_wan_vace import WanVACETransformer3DModel
    from .transformer_kandinsky import Kandinsky5Transformer3DModel

src/diffusers/models/transformers/transformer_kandinsky.py (new file, 630 lines)
@@ -0,0 +1,630 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict, Optional, Tuple, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm

logger = logging.get_logger(__name__)


# @torch.compile()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
    return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)


# @torch.compile()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_gate_sum(x, out, gate):
    return (x + gate * out).to(torch.bfloat16)


# @torch.compile()
@torch.autocast(device_type="cuda", enabled=False)
def apply_rotary(x, rope):
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    x_out = (rope * x_).sum(dim=-1)
    return x_out.reshape(*x.shape).to(torch.bfloat16)


@torch.autocast(device_type="cuda", enabled=False)
def get_freqs(dim, max_period=10000.0):
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=dim, dtype=torch.float32)
        / dim
    )
    return freqs


class TimeEmbeddings(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        assert model_dim % 2 == 0
        self.model_dim = model_dim
        self.max_period = max_period
        self.register_buffer(
            "freqs", get_freqs(model_dim // 2, max_period), persistent=False
        )
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        args = torch.outer(time, self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed


class TextEmbeddings(nn.Module):
    def __init__(self, text_dim, model_dim):
        super().__init__()
        self.in_layer = nn.Linear(text_dim, model_dim, bias=True)
        self.norm = nn.LayerNorm(model_dim, elementwise_affine=True)

    def forward(self, text_embed):
        text_embed = self.in_layer(text_embed)
        return self.norm(text_embed).type_as(text_embed)


class VisualEmbeddings(nn.Module):
    def __init__(self, visual_dim, model_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)

    def forward(self, x):
        batch_size, duration, height, width, dim = x.shape
        x = (
            x.view(
                batch_size,
                duration // self.patch_size[0],
                self.patch_size[0],
                height // self.patch_size[1],
                self.patch_size[1],
                width // self.patch_size[2],
                self.patch_size[2],
                dim,
            )
            .permute(0, 1, 3, 5, 2, 4, 6, 7)
            .flatten(4, 7)
        )
        return self.in_layer(x)

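`VisualEmbeddings` folds each (1, 2, 2) latent patch into the channel dimension before the linear projection. A quick, hedged shape check of that patchify step (the toy sizes below are made up for illustration and only mirror the model defaults):

```python
import math
import torch

# Assumed toy sizes: 1 video, 8 latent frames, a 32x48 latent grid, 16 channels,
# patch_size = (1, 2, 2) as in the Kandinsky5Transformer3DModel defaults.
patch_size = (1, 2, 2)
x = torch.randn(1, 8, 32, 48, 16)

b, d, h, w, c = x.shape
patched = (
    x.view(b, d // patch_size[0], patch_size[0], h // patch_size[1], patch_size[1],
           w // patch_size[2], patch_size[2], c)
    .permute(0, 1, 3, 5, 2, 4, 6, 7)
    .flatten(4, 7)
)
# Each token now carries prod(patch_size) * c = 64 values.
assert patched.shape == (1, 8, 16, 24, math.prod(patch_size) * c)
```
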
class RoPE1D(nn.Module):
    """
    1D Rotary Positional Embeddings for text sequences.

    Args:
        dim: Dimension of the rotary embeddings
        max_pos: Maximum sequence length
        max_period: Maximum period for sinusoidal embeddings
    """

    def __init__(self, dim, max_pos=1024, max_period=10000.0):
        super().__init__()
        self.max_period = max_period
        self.dim = dim
        self.max_pos = max_pos
        freq = get_freqs(dim // 2, max_period)
        pos = torch.arange(max_pos, dtype=freq.dtype)
        self.register_buffer("args", torch.outer(pos, freq), persistent=False)

    def forward(self, pos):
        """
        Args:
            pos: Position indices of shape [seq_len] or [batch_size, seq_len]

        Returns:
            Rotary embeddings of shape [seq_len, 1, 2, 2]
        """
        args = self.args[pos]
        cosine = torch.cos(args)
        sine = torch.sin(args)
        rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
        rope = rope.view(*rope.shape[:-1], 2, 2)
        return rope.unsqueeze(-4)

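The [2, 2] blocks produced above are plain rotation matrices, and `apply_rotary` contracts them against consecutive channel pairs. A small, hedged sanity check of that equivalence (a standalone re-implementation for illustration, not the module's API):

```python
import torch

# Assumed toy sizes: head_dim = 4, i.e. two rotary channel pairs, one position.
theta = torch.tensor([0.3, 1.2])          # one angle per channel pair
x = torch.randn(4)                        # a single head vector

cos, sin = torch.cos(theta), torch.sin(theta)
rope = torch.stack([cos, -sin, sin, cos], dim=-1).view(-1, 2, 2)   # [2, 2, 2]

# apply_rotary's formulation: broadcast-multiply and sum over the last axis.
x_pairs = x.reshape(-1, 1, 2)             # [2, 1, 2]
rotated = (rope * x_pairs).sum(dim=-1).reshape(4)

# Reference: rotate each (even, odd) channel pair by its angle.
ref = torch.empty(4)
ref[0::2] = cos * x[0::2] - sin * x[1::2]
ref[1::2] = sin * x[0::2] + cos * x[1::2]
torch.testing.assert_close(rotated, ref)
```
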
class RoPE3D(nn.Module):
    def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0):
        super().__init__()
        self.axes_dims = axes_dims
        self.max_pos = max_pos
        self.max_period = max_period

        for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)):
            freq = get_freqs(axes_dim // 2, max_period)
            pos = torch.arange(ax_max_pos, dtype=freq.dtype)
            self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False)

    @torch.autocast(device_type="cuda", enabled=False)
    def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
        batch_size, duration, height, width = shape
        args_t = self.args_0[pos[0]] / scale_factor[0]
        args_h = self.args_1[pos[1]] / scale_factor[1]
        args_w = self.args_2[pos[2]] / scale_factor[2]

        # Replicate the original logic with batch dimension
        args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1)
        args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1)
        args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1)

        # Concatenate along the last dimension
        args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1)  # [B, D, H, W, F]

        cosine = torch.cos(args)
        sine = torch.sin(args)
        rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)  # [B, D, H, W, F, 4]
        rope = rope.view(*rope.shape[:-1], 2, 2)  # [B, D, H, W, F, 2, 2]
        return rope.unsqueeze(-4)  # [B, D, H, 1, W, F, 2, 2]


class Modulation(nn.Module):
    def __init__(self, time_dim, model_dim, num_params):
        super().__init__()
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, num_params * model_dim)
        self.out_layer.weight.data.zero_()
        self.out_layer.bias.data.zero_()

    def forward(self, x):
        return self.out_layer(self.activation(x))


class MultiheadSelfAttentionEnc(nn.Module):
    def __init__(self, num_channels, head_dim):
        super().__init__()
        assert num_channels % head_dim == 0
        self.num_heads = num_channels // head_dim

        self.to_query = nn.Linear(num_channels, num_channels, bias=True)
        self.to_key = nn.Linear(num_channels, num_channels, bias=True)
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)
        self.query_norm = nn.RMSNorm(head_dim)
        self.key_norm = nn.RMSNorm(head_dim)
        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)

    def forward(self, x, rope):
        query = self.to_query(x)
        key = self.to_key(x)
        value = self.to_value(x)

        shape = query.shape[:-1]
        query = query.reshape(*shape, self.num_heads, -1)
        key = key.reshape(*shape, self.num_heads, -1)
        value = value.reshape(*shape, self.num_heads, -1)

        query = self.query_norm(query.float()).type_as(query)
        key = self.key_norm(key.float()).type_as(key)

        query = apply_rotary(query, rope).type_as(query)
        key = apply_rotary(key, rope).type_as(key)

        # Use torch's scaled_dot_product_attention
        out = F.scaled_dot_product_attention(
            query,
            key,
            value,
        ).flatten(-2, -1)

        out = self.out_layer(out)
        return out


class MultiheadSelfAttentionDec(nn.Module):
    def __init__(self, num_channels, head_dim):
        super().__init__()
        assert num_channels % head_dim == 0
        self.num_heads = num_channels // head_dim

        self.to_query = nn.Linear(num_channels, num_channels, bias=True)
        self.to_key = nn.Linear(num_channels, num_channels, bias=True)
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)
        self.query_norm = nn.RMSNorm(head_dim)
        self.key_norm = nn.RMSNorm(head_dim)
        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)

    def forward(self, x, rope, sparse_params=None):
        query = self.to_query(x)
        key = self.to_key(x)
        value = self.to_value(x)

        shape = query.shape[:-1]
        query = query.reshape(*shape, self.num_heads, -1)
        key = key.reshape(*shape, self.num_heads, -1)
        value = value.reshape(*shape, self.num_heads, -1)

        query = self.query_norm(query.float()).type_as(query)
        key = self.key_norm(key.float()).type_as(key)

        query = apply_rotary(query, rope).type_as(query)
        key = apply_rotary(key, rope).type_as(key)

        # Use standard attention (can be extended with sparse attention)
        out = F.scaled_dot_product_attention(
            query,
            key,
            value,
        ).flatten(-2, -1)

        out = self.out_layer(out)
        return out


class MultiheadCrossAttention(nn.Module):
    def __init__(self, num_channels, head_dim):
        super().__init__()
        assert num_channels % head_dim == 0
        self.num_heads = num_channels // head_dim

        self.to_query = nn.Linear(num_channels, num_channels, bias=True)
        self.to_key = nn.Linear(num_channels, num_channels, bias=True)
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)
        self.query_norm = nn.RMSNorm(head_dim)
        self.key_norm = nn.RMSNorm(head_dim)
        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)

    def forward(self, x, cond):
        query = self.to_query(x)
        key = self.to_key(cond)
        value = self.to_value(cond)

        shape, cond_shape = query.shape[:-1], key.shape[:-1]
        query = query.reshape(*shape, self.num_heads, -1)
        key = key.reshape(*cond_shape, self.num_heads, -1)
        value = value.reshape(*cond_shape, self.num_heads, -1)

        query = self.query_norm(query.float()).type_as(query)
        key = self.key_norm(key.float()).type_as(key)

        out = F.scaled_dot_product_attention(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
        ).permute(0, 2, 1, 3).flatten(-2, -1)

        out = self.out_layer(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.in_layer = nn.Linear(dim, ff_dim, bias=False)
        self.activation = nn.GELU()
        self.out_layer = nn.Linear(ff_dim, dim, bias=False)

    def forward(self, x):
        return self.out_layer(self.activation(self.in_layer(x)))


class TransformerEncoderBlock(nn.Module):
    def __init__(self, model_dim, time_dim, ff_dim, head_dim):
        super().__init__()
        self.text_modulation = Modulation(time_dim, model_dim, 6)

        self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.self_attention = MultiheadSelfAttentionEnc(model_dim, head_dim)

        self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.feed_forward = FeedForward(model_dim, ff_dim)

    def forward(self, x, time_embed, rope):
        self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
        shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)

        out = self.self_attention_norm(x)
        out = out * (scale + 1.0) + shift
        out = self.self_attention(out, rope)
        x = x + gate * out

        shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
        out = self.feed_forward_norm(x)
        out = out * (scale + 1.0) + shift
        out = self.feed_forward(out)
        x = x + gate * out
        return x


class TransformerDecoderBlock(nn.Module):
    def __init__(self, model_dim, time_dim, ff_dim, head_dim):
        super().__init__()
        self.visual_modulation = Modulation(time_dim, model_dim, 9)

        self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.self_attention = MultiheadSelfAttentionDec(model_dim, head_dim)

        self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.cross_attention = MultiheadCrossAttention(model_dim, head_dim)

        self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.feed_forward = FeedForward(model_dim, ff_dim)

    def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
        self_attn_params, cross_attn_params, ff_params = torch.chunk(
            self.visual_modulation(time_embed), 3, dim=-1
        )
        shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)

        visual_out = self.self_attention_norm(visual_embed)
        visual_out = visual_out * (scale + 1.0) + shift
        visual_out = self.self_attention(visual_out, rope, sparse_params)
        visual_embed = visual_embed + gate * visual_out

        shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
        visual_out = self.cross_attention_norm(visual_embed)
        visual_out = visual_out * (scale + 1.0) + shift
        visual_out = self.cross_attention(visual_out, text_embed)
        visual_embed = visual_embed + gate * visual_out

        shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
        visual_out = self.feed_forward_norm(visual_embed)
        visual_out = visual_out * (scale + 1.0) + shift
        visual_out = self.feed_forward(visual_out)
        visual_embed = visual_embed + gate * visual_out
        return visual_embed


class OutLayer(nn.Module):
    def __init__(self, model_dim, time_dim, visual_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.modulation = Modulation(time_dim, model_dim, 2)
        self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.out_layer = nn.Linear(
            model_dim, math.prod(patch_size) * visual_dim, bias=True
        )

    def forward(self, visual_embed, text_embed, time_embed):
        # Handle the new batch dimension: [batch, duration, height, width, model_dim]
        batch_size, duration, height, width, _ = visual_embed.shape

        shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)

        # Apply modulation with proper broadcasting for the new shape
        visual_embed = apply_scale_shift_norm(
            self.norm,
            visual_embed,
            scale[:, None, None, None],  # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1]
            shift[:, None, None, None],  # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1]
        ).type_as(visual_embed)

        x = self.out_layer(visual_embed)

        # Now x has shape [batch, duration, height, width, patch_prod * visual_dim]
        x = (
            x.view(
                batch_size,
                duration,
                height,
                width,
                -1,
                self.patch_size[0],
                self.patch_size[1],
                self.patch_size[2],
            )
            .permute(0, 5, 1, 6, 2, 7, 3, 4)  # [batch, patch_t, duration, patch_h, height, patch_w, width, features]
            .flatten(1, 2)  # [batch, patch_t * duration, height, patch_w, width, features]
            .flatten(2, 3)  # [batch, patch_t * duration, patch_h * height, width, features]
            .flatten(3, 4)  # [batch, patch_t * duration, patch_h * height, patch_w * width]
        )
        return x


@maybe_allow_in_graph
class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
    r"""
    A 3D Transformer model for video generation used in Kandinsky 5.0.

    This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods implemented for all models (such as downloading or saving).

    Args:
        in_visual_dim (`int`, defaults to 16):
            Number of channels in the input visual latent.
        out_visual_dim (`int`, defaults to 16):
            Number of channels in the output visual latent.
        time_dim (`int`, defaults to 512):
            Dimension of the time embeddings.
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            Patch size for the visual embeddings (temporal, height, width).
        model_dim (`int`, defaults to 1792):
            Hidden dimension of the transformer model.
        ff_dim (`int`, defaults to 7168):
            Intermediate dimension of the feed-forward networks.
        num_text_blocks (`int`, defaults to 2):
            Number of transformer blocks in the text encoder.
        num_visual_blocks (`int`, defaults to 32):
            Number of transformer blocks in the visual decoder.
        axes_dims (`Tuple[int]`, defaults to `(16, 24, 24)`):
            Dimensions for the rotary positional embeddings (temporal, height, width).
        visual_cond (`bool`, defaults to `True`):
            Whether to use visual conditioning (for image/video conditioning).
        in_text_dim (`int`, defaults to 3584):
            Dimension of the text embeddings from Qwen2.5-VL.
        in_text_dim2 (`int`, defaults to 768):
            Dimension of the pooled text embeddings from CLIP.
    """

    @register_to_config
    def __init__(
        self,
        in_visual_dim: int = 16,
        out_visual_dim: int = 16,
        time_dim: int = 512,
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        model_dim: int = 1792,
        ff_dim: int = 7168,
        num_text_blocks: int = 2,
        num_visual_blocks: int = 32,
        axes_dims: Tuple[int, int, int] = (16, 24, 24),
        visual_cond: bool = True,
        in_text_dim: int = 3584,
        in_text_dim2: int = 768,
    ):
        super().__init__()

        self.in_visual_dim = in_visual_dim
        self.model_dim = model_dim
        self.patch_size = patch_size
        self.visual_cond = visual_cond

        # Calculate head dimension for attention
        head_dim = sum(axes_dims)

        # Determine visual embedding dimension based on conditioning
        visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim

        # 1. Embedding layers
        self.time_embeddings = TimeEmbeddings(model_dim, time_dim)
        self.text_embeddings = TextEmbeddings(in_text_dim, model_dim)
        self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim)
        self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size)

        # 2. Rotary positional embeddings
        self.text_rope_embeddings = RoPE1D(head_dim)
        self.visual_rope_embeddings = RoPE3D(axes_dims)

        # 3. Transformer blocks
        self.text_transformer_blocks = nn.ModuleList([
            TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim)
            for _ in range(num_text_blocks)
        ])

        self.visual_transformer_blocks = nn.ModuleList([
            TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
            for _ in range(num_visual_blocks)
        ])

        # 4. Output layer
        self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_text_embed: torch.Tensor,
        timestep: torch.Tensor,
        visual_rope_pos: List[torch.Tensor],
        text_rope_pos: torch.Tensor,
        scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0),
        sparse_params: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        Forward pass of the Kandinsky 5.0 3D Transformer.

        Args:
            hidden_states (`torch.Tensor`):
                Input visual latent tensor of shape `(batch_size, num_frames, height, width, channels)`.
            encoder_hidden_states (`torch.Tensor`):
                Text embeddings from Qwen2.5-VL of shape `(batch_size, sequence_length, text_dim)`.
            pooled_text_embed (`torch.Tensor`):
                Pooled text embeddings from CLIP of shape `(batch_size, pooled_text_dim)`.
            timestep (`torch.Tensor`):
                Timestep tensor of shape `(batch_size,)` or `(batch_size * num_frames,)`.
            visual_rope_pos (`List[torch.Tensor]`):
                List of tensors for visual rotary positional embeddings [temporal, height, width].
            text_rope_pos (`torch.Tensor`):
                Tensor for text rotary positional embeddings.
            scale_factor (`Tuple[float, float, float]`, defaults to `(1.0, 1.0, 1.0)`):
                Scale factors for rotary positional embeddings.
            sparse_params (`Dict[str, Any]`, *optional*):
                Parameters for sparse attention.
            return_dict (`bool`, defaults to `True`):
                Whether to return a dictionary or a tensor.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
                If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned,
                otherwise a `tuple` where the first element is the sample tensor.
        """
        batch_size, num_frames, height, width, channels = hidden_states.shape

        # 1. Process text embeddings
        text_embed = self.text_embeddings(encoder_hidden_states)
        time_embed = self.time_embeddings(timestep)

        # Add pooled text embedding to time embedding
        pooled_embed = self.pooled_text_embeddings(pooled_text_embed)
        time_embed = time_embed + pooled_embed

        # visual_embed shape: [batch_size, seq_len, model_dim]
        visual_embed = self.visual_embeddings(hidden_states)

        # 3. Text rotary embeddings
        text_rope = self.text_rope_embeddings(text_rope_pos)

        # 4. Text transformer blocks
        for text_block in self.text_transformer_blocks:
            if self.gradient_checkpointing and self.training:
                text_embed = torch.utils.checkpoint.checkpoint(
                    text_block, text_embed, time_embed, text_rope, use_reentrant=False
                )
            else:
                text_embed = text_block(text_embed, time_embed, text_rope)

        # 5. Prepare visual rope
        visual_shape = visual_embed.shape[:-1]
        visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)

        visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1])
        visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:]))

        # 6. Visual transformer blocks
        for visual_block in self.visual_transformer_blocks:
            if self.gradient_checkpointing and self.training:
                visual_embed = torch.utils.checkpoint.checkpoint(
                    visual_block,
                    visual_embed,
                    text_embed,
                    time_embed,
                    visual_rope,
                    # visual_rope_flat,
                    sparse_params,
                    use_reentrant=False,
                )
            else:
                visual_embed = visual_block(
                    visual_embed, text_embed, time_embed, visual_rope, sparse_params
                )

        # 7. Output projection
        visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1)
        output = self.out_layer(visual_embed, text_embed, time_embed)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

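For orientation, a hedged sketch of how the configuration maps to tensor shapes. It only instantiates a deliberately tiny, made-up configuration (not a released checkpoint) and reasons about shapes; it does not run a denoising step:

```python
import torch
from diffusers.models.transformers.transformer_kandinsky import Kandinsky5Transformer3DModel

# Tiny, hypothetical config: model_dim must be divisible by head_dim = sum(axes_dims).
model = Kandinsky5Transformer3DModel(
    in_visual_dim=4,
    out_visual_dim=4,
    time_dim=32,
    patch_size=(1, 2, 2),
    model_dim=32,
    ff_dim=64,
    num_text_blocks=1,
    num_visual_blocks=1,
    axes_dims=(4, 4, 8),
    visual_cond=False,
    in_text_dim=16,
    in_text_dim2=8,
)
print(sum(p.numel() for p in model.parameters()), "parameters")

# Expected I/O, following the forward() docstring:
#   hidden_states -> (batch, num_frames, height, width, channels)
#   visual tokens -> num_frames/1 * height/2 * width/2 after patchify
#   output sample -> same frame/spatial grid with out_visual_dim channels
frames, height, width = 2, 8, 8
tokens = (frames // 1) * (height // 2) * (width // 2)
print(f"{tokens} visual tokens for a {frames}x{height}x{width} latent")
```
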
@@ -382,6 +382,7 @@ else:
            "WuerstchenPriorPipeline",
        ]
    _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
    _import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
    _import_structure["skyreels_v2"] = [
        "SkyReelsV2DiffusionForcingPipeline",
        "SkyReelsV2DiffusionForcingImageToVideoPipeline",
@@ -787,6 +788,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        )
        from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
        from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
        from .kandinsky5 import Kandinsky5T2VPipeline
        from .wuerstchen import (
            WuerstchenCombinedPipeline,
            WuerstchenDecoderPipeline,

src/diffusers/pipelines/kandinsky5/__init__.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_kandinsky import Kandinsky5T2VPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
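With this registration in place (together with the top-level `__init__` hunks above), the pipeline resolves lazily. A minimal, hedged import check, assuming torch and transformers are installed:

```python
# Both paths should resolve to the same class object; the second one exercises
# the _LazyModule machinery set up above.
from diffusers import Kandinsky5T2VPipeline
from diffusers.pipelines.kandinsky5 import Kandinsky5T2VPipeline as LazyKandinsky5T2VPipeline

assert Kandinsky5T2VPipeline is LazyKandinsky5T2VPipeline
```
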
src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py (new file, 545 lines)
@@ -0,0 +1,545 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
from typing import Any, Callable, Dict, List, Optional, Union

import regex as re
import torch
from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer
import torchvision
from torchvision.transforms import ToPILImage

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import KandinskyLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo
from ...models.transformers import Kandinsky5Transformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import KandinskyPipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_ftfy_available():
    import ftfy


logger = logging.get_logger(__name__)

EXAMPLE_DOC_STRING = """
    Examples:

        ```python
        >>> import torch
        >>> from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel
        >>> from diffusers.utils import export_to_video

        >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
        >>> pipe = pipe.to("cuda")

        >>> prompt = "A cat and a dog baking a cake together in a kitchen."
        >>> negative_prompt = "Bright tones, overexposed, static, blurred details"

        >>> output = pipe(
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     height=512,
        ...     width=768,
        ...     num_frames=25,
        ...     num_inference_steps=50,
        ...     guidance_scale=5.0,
        ... ).frames[0]
        >>> export_to_video(output, "output.mp4", fps=6)
        ```
"""


class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
    r"""
    Pipeline for text-to-video generation using Kandinsky 5.0.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        transformer ([`Kandinsky5Transformer3DModel`]):
            Conditional Transformer to denoise the encoded video latents.
        vae ([`AutoencoderKLHunyuanVideo`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            Frozen text-encoder (Qwen2.5-VL).
        tokenizer ([`AutoProcessor`]):
            Tokenizer for Qwen2.5-VL.
        text_encoder_2 ([`CLIPTextModel`]):
            Frozen CLIP text encoder.
        tokenizer_2 ([`CLIPTokenizer`]):
            Tokenizer for CLIP.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        transformer: Kandinsky5Transformer3DModel,
        vae: AutoencoderKLHunyuanVideo,
        text_encoder: Qwen2_5_VLForConditionalGeneration,
        tokenizer: Qwen2VLProcessor,
        text_encoder_2: CLIPTextModel,
        tokenizer_2: CLIPTokenizer,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()

        self.register_modules(
            transformer=transformer,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            scheduler=scheduler,
        )

        self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
        self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio

    def _encode_prompt_qwen(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 256,
    ):
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt

        # Kandinsky specific prompt template
        prompt_template = "\n".join([
            "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
            "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
            "Describe the location of the video, main characters or objects and their action.",
            "Describe the dynamism of the video and presented actions.",
            "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
            "Describe the visual effects, postprocessing and transitions if they are presented in the video.",
            "Pay attention to the order of key actions shown in the scene.<|im_end|>",
            "<|im_start|>user\n{}<|im_end|>",
        ])
        crop_start = 129

        full_texts = [prompt_template.format(p) for p in prompt]

        inputs = self.tokenizer(
            text=full_texts,
            images=None,
            videos=None,
            max_length=max_sequence_length + crop_start,
            truncation=True,
            return_tensors="pt",
            padding=True,
        ).to(device)

        with torch.no_grad():
            embeds = self.text_encoder(
                input_ids=inputs["input_ids"],
                return_dict=True,
                output_hidden_states=True,
            )["hidden_states"][-1][:, crop_start:]

        attention_mask = inputs["attention_mask"][:, crop_start:]
        embeds = embeds[attention_mask.bool()]
        cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
        cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32)

        # duplicate for each generation per prompt
        batch_size = len(prompt)
        seq_len = embeds.shape[0] // batch_size
        embeds = embeds.reshape(batch_size, seq_len, -1)
        embeds = embeds.repeat(1, num_videos_per_prompt, 1)
        embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return embeds, cu_seqlens

    def _encode_prompt_clip(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
    ):
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt

        inputs = self.tokenizer_2(
            prompt,
            max_length=77,
            truncation=True,
            add_special_tokens=True,
            padding="max_length",
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]

        # duplicate for each generation per prompt
        batch_size = len(prompt)
        pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1)
        pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1)

        return pooled_embed

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        device: Optional[torch.device] = None,
    ):
        device = device or self._execution_device

        # Encode with Qwen2.5-VL
        prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(
            prompt, device, num_videos_per_prompt
        )
        pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt)

        if do_classifier_free_guidance:
            negative_prompt = negative_prompt or ""
            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(
                negative_prompt, device, num_videos_per_prompt
            )
            negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt)
        else:
            negative_prompt_embeds = None
            negative_pooled_embed = None
            negative_cu_seqlens = None

        text_embeds = {
            "text_embeds": prompt_embeds,
            "pooled_embed": pooled_embed,
        }
        negative_text_embeds = {
            "text_embeds": negative_prompt_embeds,
            "pooled_embed": negative_pooled_embed,
        } if do_classifier_free_guidance else None

        return text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        visual_cond: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
            num_channels_latents,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        if visual_cond:
            # For visual conditioning, concatenate with zeros and mask
            visual_cond = torch.zeros_like(latents)
            visual_cond_mask = torch.zeros(
                [batch_size, num_latent_frames, int(height) // self.vae_scale_factor_spatial, int(width) // self.vae_scale_factor_spatial, 1],
                dtype=latents.dtype,
                device=latents.device
            )
            latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1)

        return latents

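`prepare_latents` derives the latent grid from the VAE compression ratios. A hedged back-of-the-envelope check for the pipeline's default call values, assuming the usual HunyuanVideo VAE ratios of 4x temporal and 8x spatial (the real values come from `vae.config` at runtime):

```python
# Assumed compression ratios; the pipeline reads these from vae.config instead.
vae_scale_factor_temporal = 4
vae_scale_factor_spatial = 8

height, width, num_frames = 512, 768, 25          # __call__ defaults
num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1

latent_shape = (
    1,                                             # batch_size
    num_latent_frames,                             # 7
    height // vae_scale_factor_spatial,            # 64
    width // vae_scale_factor_spatial,             # 96
    16,                                            # num_channels_latents
)
print(latent_shape)  # (1, 7, 64, 96, 16)
```
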
    def get_velocity(
        self,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        text_embeds: Dict[str, torch.Tensor],
        negative_text_embeds: Optional[Dict[str, torch.Tensor]],
        visual_rope_pos: List[torch.Tensor],
        text_rope_pos: torch.Tensor,
        negative_text_rope_pos: torch.Tensor,
        guidance_scale: float,
        sparse_params: Optional[Dict] = None,
    ):
        # print(latents.shape, text_embeds["text_embeds"].shape, text_embeds["pooled_embed"].shape, timestep.shape, [el.shape for el in visual_rope_pos], text_rope_pos, sparse_params)

        pred_velocity = self.transformer(
            latents,
            text_embeds["text_embeds"],
            text_embeds["pooled_embed"],
            timestep * 1000,  # Scale to match training
            visual_rope_pos,
            text_rope_pos,
            scale_factor=(1, 2, 2),  # From Kandinsky config
            sparse_params=sparse_params,
            return_dict=False
        )[0]

        if guidance_scale > 1.0 and negative_text_embeds is not None:
            uncond_pred_velocity = self.transformer(
                latents,
                negative_text_embeds["text_embeds"],
                negative_text_embeds["pooled_embed"],
                timestep * 1000,
                visual_rope_pos,
                negative_text_rope_pos,
                scale_factor=(1, 2, 2),
                sparse_params=sparse_params,
                return_dict=False
            )[0]

            pred_velocity = uncond_pred_velocity + guidance_scale * (
                pred_velocity - uncond_pred_velocity
            )

        return pred_velocity

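The guidance branch above is standard classifier-free guidance applied to the predicted flow velocity, v = v_uncond + s * (v_cond - v_uncond). A small, hedged numeric illustration on dummy tensors, independent of the pipeline:

```python
import torch

guidance_scale = 5.0
v_cond = torch.tensor([1.0, 2.0])      # stand-in for the text-conditioned velocity
v_uncond = torch.tensor([0.5, 0.5])    # stand-in for the unconditional velocity

v = v_uncond + guidance_scale * (v_cond - v_uncond)
print(v)  # tensor([3., 8.]) -- the conditional direction is amplified 5x
```
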
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 768,
        num_frames: int = 25,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        scheduler_scale: float = 10.0,
        num_videos_per_prompt: int = 1,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the video generation.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to avoid during video generation.
            height (`int`, defaults to `512`):
                The height in pixels of the generated video.
            width (`int`, defaults to `768`):
                The width in pixels of the generated video.
            num_frames (`int`, defaults to `25`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps.
            guidance_scale (`float`, defaults to `5.0`):
                Guidance scale as defined in classifier-free guidance.
            scheduler_scale (`float`, defaults to `10.0`):
                Scale factor for the custom flow matching timestep schedule.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator`, *optional*):
                A torch generator to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`KandinskyPipelineOutput`] instead of a plain tuple.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step.
            callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to `["latents"]`):
                The list of tensor inputs passed to `callback_on_step_end`.

        Examples:

        Returns:
            [`~KandinskyPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the generated video frames.
        """

        # 1. Check inputs
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        # 2. Define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        else:
            batch_size = len(prompt)

        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0

        if num_frames % self.vae_scale_factor_temporal != 1:
            logger.warning(
                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
            )
            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

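        # Illustrative note (not part of the original diff): assuming vae_scale_factor_temporal == 4,
        # the adjustment above snaps num_frames to the 4k + 1 grid, e.g. 24 -> 25, 27 -> 25, 81 -> 81,
        # so that (num_frames - 1) is always divisible by the temporal compression factor.
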
        # 3. Encode input prompt
        text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            device=device,
        )

        # 4. Prepare timesteps (Kandinsky uses a custom flow matching schedule)
        timesteps = torch.linspace(1, 0, num_inference_steps + 1, device=device)
        timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps)

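        # Illustrative note (not part of the original diff): the remapping
        #     t' = s * t / (1 + (s - 1) * t), with s = scheduler_scale,
        # keeps the endpoints fixed (t=1 -> 1, t=0 -> 0) but is strongly non-uniform. With s = 10,
        # a uniform t = 0.5 maps to 10 * 0.5 / (1 + 9 * 0.5) ≈ 0.91, so half of the uniform grid is
        # squeezed into t' ∈ [0.91, 1]: steps are small at high noise levels and grow toward t' = 0.
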
        # 5. Prepare latent variables
        num_channels_latents = 16
        latents = self.prepare_latents(
            batch_size=batch_size * num_videos_per_prompt,
            num_channels_latents=num_channels_latents,
            height=height,
            width=width,
            num_frames=num_frames,
            visual_cond=self.transformer.visual_cond,
            dtype=self.transformer.dtype,
            device=device,
            generator=generator,
            latents=latents,
        )

        # 6. Prepare RoPE positions
        visual_rope_pos = [
            torch.arange(num_frames // 4 + 1, device=device),
            torch.arange(height // 8 // 2, device=device),  # spatial patch size 2
            torch.arange(width // 8 // 2, device=device),
        ]

        text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device)

        negative_text_rope_pos = (
            torch.arange(negative_cu_seqlens[-1].item(), device=device)
            if negative_cu_seqlens is not None
            else None
        )

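        # Illustrative note (not part of the original diff): after the rounding in step 2,
        # num_frames has the form 4k + 1, so num_frames // 4 + 1 equals the latent frame count
        # (num_frames - 1) // vae_scale_factor_temporal + 1 used in prepare_latents, assuming a
        # temporal compression factor of 4; the spatial axes use the //8 VAE stride and //2 patching.
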
        # 7. Prepare sparse attention params if needed
        sparse_params = None  # can be extended based on the Kandinsky attention config

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, (timestep, timestep_diff) in enumerate(zip(timesteps[:-1], torch.diff(timesteps))):
                # Lift the scalar timestep to a 1D tensor for the transformer
                time = timestep.unsqueeze(0)

                pred_velocity = self.get_velocity(
                    latents,
                    time,
                    text_embeds,
                    negative_text_embeds,
                    visual_rope_pos,
                    text_rope_pos,
                    negative_text_rope_pos,
                    guidance_scale,
                    sparse_params,
                )

                # Update the noise channels with a flow matching Euler step
                latents[:, :, :, :, :16] = latents[:, :, :, :, :16] + timestep_diff * pred_velocity

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 1 == 0):
                    progress_bar.update()

        # Keep only the first 16 noise channels (a no-op when there is no visual conditioning)
        latents = latents[:, :, :, :, :16]

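        # Illustrative note (not part of the original diff): the in-loop update above is a plain
        # Euler step of the flow ODE, x_{t+dt} = x_t + dt * v(x_t, t). Because `timesteps` is
        # descending, torch.diff yields negative dt values, so the latents move from noise (t = 1)
        # toward data (t = 0); only the first 16 channels are integrated, while any
        # visual-conditioning channels and the mask stay fixed.
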
        # 9. Decode latents to video
        if output_type != "latent":
            latents = latents.to(self.vae.dtype)
            # Reshape from channels-last latents to the (batch, channels, frames, height, width)
            # layout expected by the VAE
            video = latents.reshape(
                batch_size,
                num_videos_per_prompt,
                (num_frames - 1) // self.vae_scale_factor_temporal + 1,
                height // 8,
                width // 8,
                16,
            )
            video = video.permute(0, 1, 5, 2, 3, 4)  # [batch, num_videos, channels, frames, height, width]
            video = video.reshape(
                batch_size * num_videos_per_prompt,
                16,
                (num_frames - 1) // self.vae_scale_factor_temporal + 1,
                height // 8,
                width // 8,
            )

            # Rescale and decode
            video = video / self.vae.config.scaling_factor
            video = self.vae.decode(video).sample
            video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)

            # Convert to the requested output format
            if output_type == "pil":
                if num_frames == 1:
                    # Single image per video
                    video = [ToPILImage()(frame.squeeze(1)) for frame in video]
                else:
                    # List of per-video frame tensors
                    video = [video[i] for i in range(video.shape[0])]
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return KandinskyPipelineOutput(frames=video)
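
# Illustrative usage sketch (not part of the original diff). The checkpoint id below is a
# hypothetical placeholder; follow the model card of the released Kandinsky 5 weights. As the
# pipeline is written here, `output.frames[0]` is a uint8 tensor of shape
# (channels, frames, height, width) when `output_type="pil"` and num_frames > 1:
#
#     import torch
#     from PIL import Image
#     from diffusers import Kandinsky5T2VPipeline
#     from diffusers.utils import export_to_video
#
#     pipe = Kandinsky5T2VPipeline.from_pretrained(
#         "<kandinsky-5-t2v-checkpoint>", torch_dtype=torch.bfloat16
#     ).to("cuda")
#     output = pipe(prompt="A cat surfing a wave at sunset", num_frames=25, height=512, width=768)
#     frames = [Image.fromarray(f) for f in output.frames[0].permute(1, 2, 3, 0).cpu().numpy()]
#     export_to_video(frames, "kandinsky5.mp4", fps=8)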

src/diffusers/pipelines/kandinsky5/pipeline_output.py (new file, +20 lines)
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class KandinskyPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Wan pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
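
# Illustrative note (not part of the original diff): like other diffusers `BaseOutput` subclasses,
# `KandinskyPipelineOutput` supports attribute access, dict-style access and tuple conversion:
#
#     out = KandinskyPipelineOutput(frames=video)
#     out.frames, out["frames"], out.to_tuple()[0]   # all refer to the same object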