mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix 5sec generation

leffff
2025-10-09 15:09:50 +00:00
parent 7db6093c53
commit a0cf07f7e0
2 changed files with 369 additions and 344 deletions


@@ -13,21 +13,27 @@
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple, Union, List
from typing import Any, Dict, List, Optional, Tuple, Union
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import BoolTensor, IntTensor, Tensor, nn
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
flex_attention)
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers,
unscale_lora_layers)
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..embeddings import (PixArtAlphaTextProjection, TimestepEmbedding,
Timesteps, get_1d_rotary_pos_embed)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
@@ -35,39 +41,14 @@ from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__)
if torch.cuda.get_device_capability()[0] >= 9:
try:
from flash_attn_interface import flash_attn_func as FA
except:
FA = None
try:
from flash_attn import flash_attn_func as FA
except:
FA = None
else:
try:
from flash_attn import flash_attn_func as FA
except:
FA = None
def exist(item):
return item is not None
# @torch.compile()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)
# @torch.compile()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_gate_sum(x, out, gate):
return (x + gate * out).to(torch.bfloat16)
# @torch.compile()
@torch.autocast(device_type="cuda", enabled=False)
def apply_rotary(x, rope):
x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
x_out = (rope * x_).sum(dim=-1)
return x_out.reshape(*x.shape).to(torch.bfloat16)
def freeze(model):
for p in model.parameters():
p.requires_grad = False
return model
@torch.autocast(device_type="cuda", enabled=False)
@@ -80,6 +61,116 @@ def get_freqs(dim, max_period=10000.0):
return freqs
def fractal_flatten(x, rope, shape, block_mask=False):
if block_mask:
pixel_size = 8
x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0)
rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0)
x = x.flatten(1, 2)
rope = rope.flatten(1, 2)
else:
x = x.flatten(1, 3)
rope = rope.flatten(1, 3)
return x, rope
def fractal_unflatten(x, shape, block_mask=False):
if block_mask:
pixel_size = 8
x = x.reshape(-1, pixel_size**2, *x.shape[1:])
x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0)
else:
x = x.reshape(*shape, *x.shape[2:])
return x
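A minimal round-trip sketch of the non-block path (hypothetical sizes; with block_mask=True the tokens are instead grouped into 1×8×8 pixel blocks first so a block-sparse attention mask can act on contiguous patches):
import torch
B, T, H, W, C = 1, 4, 6, 6, 8
x = torch.randn(B, T, H, W, C)
rope = torch.randn(B, T, H, W, 1, 2, 2)  # per-position rotation blocks
tokens, rope_flat = fractal_flatten(x, rope, (B, T, H, W), block_mask=False)
assert tokens.shape == (B, T * H * W, C)          # grid -> token sequence
assert torch.equal(fractal_unflatten(tokens, (B, T, H, W), block_mask=False), x)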
def local_patching(x, shape, group_size, dim=0):
duration, height, width = shape
g1, g2, g3 = group_size
x = x.reshape(
*x.shape[:dim],
duration // g1,
g1,
height // g2,
g2,
width // g3,
g3,
*x.shape[dim + 3 :]
)
x = x.permute(
*range(len(x.shape[:dim])),
dim,
dim + 2,
dim + 4,
dim + 1,
dim + 3,
dim + 5,
*range(dim + 6, len(x.shape))
)
x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
return x
def local_merge(x, shape, group_size, dim=0):
duration, height, width = shape
g1, g2, g3 = group_size
x = x.reshape(
*x.shape[:dim],
duration // g1,
height // g2,
width // g3,
g1,
g2,
g3,
*x.shape[dim + 2 :]
)
x = x.permute(
*range(len(x.shape[:dim])),
dim,
dim + 3,
dim + 1,
dim + 4,
dim + 2,
dim + 5,
*range(dim + 6, len(x.shape))
)
x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
return x
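For reference, local_patching regroups a (T, H, W, C) grid into (num_blocks, block_volume, C) local windows, and local_merge is its exact inverse; a minimal sketch with hypothetical sizes:
import torch
T, H, W, C = 2, 16, 16, 4
x = torch.randn(T, H, W, C)
blocks = local_patching(x, (T, H, W), (1, 8, 8), dim=0)
assert blocks.shape == (T * (H // 8) * (W // 8), 8 * 8, C)
assert torch.equal(local_merge(blocks, (T, H, W), (1, 8, 8), dim=0), x)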
def sdpa(q, k, v):
query = q.transpose(1, 2).contiguous()
key = k.transpose(1, 2).contiguous()
value = v.transpose(1, 2).contiguous()
out = (
F.scaled_dot_product_attention(
query,
key,
value
)
.transpose(1, 2)
.contiguous()
)
return out
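sdpa takes tensors in the flash-attn layout (batch, seq_len, num_heads, head_dim), transposes to the (batch, num_heads, seq_len, head_dim) layout F.scaled_dot_product_attention expects, and transposes back. A usage sketch with assumed shapes:
import torch
q = torch.randn(2, 128, 8, 64)
k, v = torch.randn_like(q), torch.randn_like(q)
out = sdpa(q, k, v)                 # (2, 128, 8, 64), same layout as inputs
tokens = out.flatten(-2, -1)        # (2, 128, 512), as the attention modules do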
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_gate_sum(x, out, gate):
return (x + gate * out).to(torch.bfloat16)
@torch.autocast(device_type="cuda", enabled=False)
def apply_rotary(x, rope):
x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
x_out = (rope * x_).sum(dim=-1)
return x_out.reshape(*x.shape).to(torch.bfloat16)
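apply_rotary treats rope as per-position 2×2 rotation matrices: each consecutive channel pair (x0, x1) becomes (cos·x0 − sin·x1, sin·x0 + cos·x1). A worked check with hypothetical tiny shapes (the output is cast to bfloat16, hence approximate):
import math
import torch
theta = torch.tensor([math.pi / 2])
cos, sin = torch.cos(theta), torch.sin(theta)
rope = torch.stack([cos, -sin, sin, cos], dim=-1).view(1, 1, 2, 2)
x = torch.tensor([[1.0, 0.0]])      # one token, one channel pair
out = apply_rotary(x, rope)         # rotates (1, 0) by 90°: ~[[0.0, 1.0]]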
class TimeEmbeddings(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0):
super().__init__()
@@ -93,12 +184,16 @@ class TimeEmbeddings(nn.Module):
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, time):
args = torch.outer(time, self.freqs.to(device=time.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
def reset_dtype(self):
self.freqs = get_freqs(self.model_dim // 2, self.max_period)
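TimeEmbeddings is a standard sinusoidal timestep embedding followed by a two-layer SiLU MLP; reset_dtype rebuilds the float32 frequency table after the module has been cast (the pipeline calls it before sampling, see the second file below). A minimal sketch with assumed dimensions:
import torch
te = TimeEmbeddings(model_dim=64, time_dim=128)
t = torch.tensor([0.0, 0.5, 1.0])   # one timestep per batch element
emb = te(t)                         # outer(t, freqs) -> [cos, sin] -> MLP; (3, 128)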
class TextEmbeddings(nn.Module):
def __init__(self, text_dim, model_dim):
@@ -116,7 +211,7 @@ class VisualEmbeddings(nn.Module):
super().__init__()
self.patch_size = patch_size
self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)
def forward(self, x):
batch_size, duration, height, width, dim = x.shape
x = (
@@ -124,7 +219,7 @@ class VisualEmbeddings(nn.Module):
batch_size,
duration // self.patch_size[0],
self.patch_size[0],
height // self.patch_size[1],
height // self.patch_size[1],
self.patch_size[1],
width // self.patch_size[2],
self.patch_size[2],
@@ -137,15 +232,6 @@ class VisualEmbeddings(nn.Module):
class RoPE1D(nn.Module):
"""
1D Rotary Positional Embeddings for text sequences.
Args:
dim: Dimension of the rotary embeddings
max_pos: Maximum sequence length
max_period: Maximum period for sinusoidal embeddings
"""
def __init__(self, dim, max_pos=1024, max_period=10000.0):
super().__init__()
self.max_period = max_period
@@ -153,22 +239,21 @@ class RoPE1D(nn.Module):
self.max_pos = max_pos
freq = get_freqs(dim // 2, max_period)
pos = torch.arange(max_pos, dtype=freq.dtype)
self.register_buffer("args", torch.outer(pos, freq), persistent=False)
self.register_buffer(f"args", torch.outer(pos, freq), persistent=False)
@torch.autocast(device_type="cuda", enabled=False)
def forward(self, pos):
"""
Args:
pos: Position indices of shape [seq_len] or [batch_size, seq_len]
Returns:
Rotary embeddings of shape [seq_len, 1, dim/2, 2, 2]
"""
args = self.args[pos]
cosine = torch.cos(args)
sine = torch.sin(args)
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
rope = rope.view(*rope.shape[:-1], 2, 2)
return rope.unsqueeze(-4)
def reset_dtype(self):
freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device)
pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device)
self.args = torch.outer(pos, freq)
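RoPE1D precomputes outer(pos, freqs) once; each forward call looks up the requested positions and packs cos/sin into 2×2 rotation blocks, with unsqueeze(-4) adding an axis that broadcasts over attention heads. A sketch with assumed dimensions:
import torch
rope1d = RoPE1D(dim=64, max_pos=1024)
rope = rope1d(torch.arange(8))      # (8, 1, 32, 2, 2): (seq, heads, dim/2, 2, 2)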
class RoPE3D(nn.Module):
@@ -186,22 +271,29 @@ class RoPE3D(nn.Module):
@torch.autocast(device_type="cuda", enabled=False)
def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
batch_size, duration, height, width = shape
args_t = self.args_0[pos[0]] / scale_factor[0]
args_h = self.args_1[pos[1]] / scale_factor[1]
args_w = self.args_2[pos[2]] / scale_factor[2]
args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1)
args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1)
args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1)
args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1)
args = torch.cat(
[
args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1),
args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1),
args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1),
],
dim=-1,
)
cosine = torch.cos(args)
sine = torch.sin(args)
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
rope = rope.view(*rope.shape[:-1], 2, 2)
return rope.unsqueeze(-4)
def reset_dtype(self):
for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)):
freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device)
pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device)
setattr(self, f'args_{i}', torch.outer(pos, freq))
class Modulation(nn.Module):
@@ -212,10 +304,11 @@ class Modulation(nn.Module):
self.out_layer.weight.data.zero_()
self.out_layer.bias.data.zero_()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x):
return self.out_layer(self.activation(x))
class MultiheadSelfAttentionEnc(nn.Module):
def __init__(self, num_channels, head_dim):
super().__init__()
@@ -227,9 +320,10 @@ class MultiheadSelfAttentionEnc(nn.Module):
self.to_value = nn.Linear(num_channels, num_channels, bias=True)
self.query_norm = nn.RMSNorm(head_dim)
self.key_norm = nn.RMSNorm(head_dim)
self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
def forward(self, x, rope):
def get_qkv(self, x):
query = self.to_query(x)
key = self.to_key(x)
value = self.to_value(x)
@@ -239,26 +333,31 @@ class MultiheadSelfAttentionEnc(nn.Module):
key = key.reshape(*shape, self.num_heads, -1)
value = value.reshape(*shape, self.num_heads, -1)
query = self.query_norm(query.float()).type_as(query)
key = self.key_norm(key.float()).type_as(key)
return query, key, value
def norm_qk(self, q, k):
q = self.query_norm(q.float()).type_as(q)
k = self.key_norm(k.float()).type_as(k)
return q, k
def scaled_dot_product_attention(self, query, key, value):
out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
return out
def out_l(self, x):
return self.out_layer(x)
def forward(self, x, rope):
query, key, value = self.get_qkv(x)
query, key = self.norm_qk(query, key)
query = apply_rotary(query, rope).type_as(query)
key = apply_rotary(key, rope).type_as(key)
# Use torch's scaled_dot_product_attention
# print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionEnc SHAPE")
# out = F.scaled_dot_product_attention(
# query.permute(0, 2, 1, 3),
# key.permute(0, 2, 1, 3),
# value.permute(0, 2, 1, 3),
# ).permute(0, 2, 1, 3).flatten(-2, -1)
out = FA(q=query, k=key, v=value).flatten(-2, -1)
out = self.scaled_dot_product_attention(query, key, value)
out = self.out_layer(out)
out = self.out_l(out)
return out
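The refactor splits each attention forward into overridable stages (get_qkv, norm_qk, the attention call, out_l), so a different backend can be swapped in by overriding a single hook; note the decoder and cross-attention variants name that hook attention rather than scaled_dot_product_attention. A hypothetical subclass restoring a FlashAttention backend, assuming flash_attn is installed:
from flash_attn import flash_attn_func

class FlashSelfAttentionEnc(MultiheadSelfAttentionEnc):
    def scaled_dot_product_attention(self, query, key, value):
        # q/k/v arrive as (batch, seq, heads, head_dim), flash-attn's layout
        return flash_attn_func(q=query, k=key, v=value).flatten(-2, -1)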
class MultiheadSelfAttentionDec(nn.Module):
def __init__(self, num_channels, head_dim):
super().__init__()
@@ -270,9 +369,10 @@ class MultiheadSelfAttentionDec(nn.Module):
self.to_value = nn.Linear(num_channels, num_channels, bias=True)
self.query_norm = nn.RMSNorm(head_dim)
self.key_norm = nn.RMSNorm(head_dim)
self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
def forward(self, x, rope, sparse_params=None):
def get_qkv(self, x):
query = self.to_query(x)
key = self.to_key(x)
value = self.to_value(x)
@@ -282,24 +382,29 @@ class MultiheadSelfAttentionDec(nn.Module):
key = key.reshape(*shape, self.num_heads, -1)
value = value.reshape(*shape, self.num_heads, -1)
query = self.query_norm(query.float()).type_as(query)
key = self.key_norm(key.float()).type_as(key)
return query, key, value
def norm_qk(self, q, k):
q = self.query_norm(q.float()).type_as(q)
k = self.key_norm(k.float()).type_as(k)
return q, k
def attention(self, query, key, value):
out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
return out
def out_l(self, x):
return self.out_layer(x)
def forward(self, x, rope, sparse_params=None):
query, key, value = self.get_qkv(x)
query, key = self.norm_qk(query, key)
query = apply_rotary(query, rope).type_as(query)
key = apply_rotary(key, rope).type_as(key)
# Use standard attention (can be extended with sparse attention)
# out = F.scaled_dot_product_attention(
# query.permute(0, 2, 1, 3),
# key.permute(0, 2, 1, 3),
# value.permute(0, 2, 1, 3),
# ).permute(0, 2, 1, 3).flatten(-2, -1)
# print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionDec SHAPE")
out = FA(q=query, k=key, v=value).flatten(-2, -1)
out = self.attention(query, key, value)
out = self.out_layer(out)
out = self.out_l(out)
return out
@@ -314,32 +419,39 @@ class MultiheadCrossAttention(nn.Module):
self.to_value = nn.Linear(num_channels, num_channels, bias=True)
self.query_norm = nn.RMSNorm(head_dim)
self.key_norm = nn.RMSNorm(head_dim)
self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
def forward(self, x, cond):
def get_qkv(self, x, cond):
query = self.to_query(x)
key = self.to_key(cond)
value = self.to_value(cond)
shape, cond_shape = query.shape[:-1], key.shape[:-1]
query = query.reshape(*shape, self.num_heads, -1)
key = key.reshape(*cond_shape, self.num_heads, -1)
value = value.reshape(*cond_shape, self.num_heads, -1)
query = self.query_norm(query.float()).type_as(query)
key = self.key_norm(key.float()).type_as(key)
# out = F.scaled_dot_product_attention(
# query.permute(0, 2, 1, 3),
# key.permute(0, 2, 1, 3),
# value.permute(0, 2, 1, 3),
# ).permute(0, 2, 1, 3).flatten(-2, -1)
# print(query.shape, key.shape, value.shape, "QKV MultiheadCrossAttention SHAPE")
out = FA(q=query, k=key, v=value).flatten(-2, -1)
return query, key, value
out = self.out_layer(out)
def norm_qk(self, q, k):
q = self.query_norm(q.float()).type_as(q)
k = self.key_norm(k.float()).type_as(k)
return q, k
def attention(self, query, key, value):
out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
return out
def out_l(self, x):
return self.out_layer(x)
def forward(self, x, cond):
query, key, value = self.get_qkv(x, cond)
query, key = self.norm_qk(query, key)
out = self.attention(query, key, value)
out = self.out_l(out)
return out
@@ -354,6 +466,48 @@ class FeedForward(nn.Module):
return self.out_layer(self.activation(self.in_layer(x)))
class OutLayer(nn.Module):
def __init__(self, model_dim, time_dim, visual_dim, patch_size):
super().__init__()
self.patch_size = patch_size
self.modulation = Modulation(time_dim, model_dim, 2)
self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.out_layer = nn.Linear(
model_dim, math.prod(patch_size) * visual_dim, bias=True
)
def forward(self, visual_embed, text_embed, time_embed):
shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
visual_embed = apply_scale_shift_norm(
self.norm,
visual_embed,
scale[:, None, None],
shift[:, None, None],
).type_as(visual_embed)
x = self.out_layer(visual_embed)
batch_size, duration, height, width, _ = x.shape
x = (
x.view(
batch_size,
duration,
height,
width,
-1,
self.patch_size[0],
self.patch_size[1],
self.patch_size[2],
)
.permute(0, 1, 5, 2, 6, 3, 7, 4)
.flatten(1, 2)
.flatten(2, 3)
.flatten(3, 4)
)
return x
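The tail of OutLayer.forward is an un-patchify: tokens on the (T', H', W') patch grid are expanded back to (T'·p_t, H'·p_h, W'·p_w) pixels. A shape-only sketch with hypothetical sizes:
import torch
B, T, H, W = 1, 4, 8, 8             # latent grid in patch units
p, C = (1, 2, 2), 16                # patch size, output channels
x = torch.randn(B, T, H, W, C * p[0] * p[1] * p[2])
x = (x.view(B, T, H, W, -1, *p)
     .permute(0, 1, 5, 2, 6, 3, 7, 4)
     .flatten(1, 2).flatten(2, 3).flatten(3, 4))
assert x.shape == (B, T * p[0], H * p[1], W * p[2], C)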
class TransformerEncoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim):
super().__init__()
@@ -366,9 +520,7 @@ class TransformerEncoderBlock(nn.Module):
self.feed_forward = FeedForward(model_dim, ff_dim)
def forward(self, x, time_embed, rope):
self_attn_params, ff_params = torch.chunk(
self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
)
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
out = self.self_attention(out, rope)
@@ -416,246 +568,116 @@ class TransformerDecoderBlock(nn.Module):
return visual_embed
class OutLayer(nn.Module):
def __init__(self, model_dim, time_dim, visual_dim, patch_size):
super().__init__()
self.patch_size = patch_size
self.modulation = Modulation(time_dim, model_dim, 2)
self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
self.out_layer = nn.Linear(
model_dim, math.prod(patch_size) * visual_dim, bias=True
)
def forward(self, visual_embed, text_embed, time_embed):
# Handle the new batch dimension: [batch, duration, height, width, model_dim]
batch_size, duration, height, width, _ = visual_embed.shape
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
# Apply modulation with proper broadcasting for the new shape
visual_embed = apply_scale_shift_norm(
self.norm,
visual_embed,
scale[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1]
shift[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1]
).type_as(visual_embed)
x = self.out_layer(visual_embed)
# Now x has shape [batch, duration, height, width, patch_prod * visual_dim]
x = (
x.view(
batch_size,
duration,
height,
width,
-1,
self.patch_size[0],
self.patch_size[1],
self.patch_size[2],
)
.permute(0, 5, 1, 6, 2, 7, 3, 4) # [batch, patch_t, duration, patch_h, height, patch_w, width, features]
.flatten(1, 2) # [batch, patch_t * duration, height, patch_w, width, features]
.flatten(2, 3) # [batch, patch_t * duration, patch_h * height, width, features]
.flatten(3, 4) # [batch, patch_t * duration, patch_h * height, patch_w * width]
)
return x
@maybe_allow_in_graph
class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
r"""
A 3D Transformer model for video generation used in Kandinsky 5.0.
This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods implemented for all models (such as downloading or saving).
Args:
in_visual_dim (`int`, defaults to 16):
Number of channels in the input visual latent.
out_visual_dim (`int`, defaults to 16):
Number of channels in the output visual latent.
time_dim (`int`, defaults to 512):
Dimension of the time embeddings.
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
Patch size for the visual embeddings (temporal, height, width).
model_dim (`int`, defaults to 1792):
Hidden dimension of the transformer model.
ff_dim (`int`, defaults to 7168):
Intermediate dimension of the feed-forward networks.
num_text_blocks (`int`, defaults to 2):
Number of transformer blocks in the text encoder.
num_visual_blocks (`int`, defaults to 32):
Number of transformer blocks in the visual decoder.
axes_dims (`Tuple[int]`, defaults to `(16, 24, 24)`):
Dimensions for the rotary positional embeddings (temporal, height, width).
visual_cond (`bool`, defaults to `True`):
Whether to use visual conditioning (for image/video conditioning).
in_text_dim (`int`, defaults to 3584):
Dimension of the text embeddings from Qwen2.5-VL.
in_text_dim2 (`int`, defaults to 768):
Dimension of the pooled text embeddings from CLIP.
"""
A 3D Diffusion Transformer model for video-like data.
"""
@register_to_config
def __init__(
self,
in_visual_dim: int = 16,
out_visual_dim: int = 16,
time_dim: int = 512,
patch_size: Tuple[int, int, int] = (1, 2, 2),
model_dim: int = 1792,
ff_dim: int = 7168,
num_text_blocks: int = 2,
num_visual_blocks: int = 32,
axes_dims: Tuple[int, int, int] = (16, 24, 24),
visual_cond: bool = True,
in_text_dim: int = 3584,
in_text_dim2: int = 768,
in_visual_dim=4,
in_text_dim=3584,
in_text_dim2=768,
time_dim=512,
out_visual_dim=4,
patch_size=(1, 2, 2),
model_dim=2048,
ff_dim=5120,
num_text_blocks=2,
num_visual_blocks=32,
axes_dims=(16, 24, 24),
visual_cond=False,
):
super().__init__()
head_dim = sum(axes_dims)
self.in_visual_dim = in_visual_dim
self.model_dim = model_dim
self.patch_size = patch_size
self.visual_cond = visual_cond
# Calculate head dimension for attention
head_dim = sum(axes_dims)
# Determine visual embedding dimension based on conditioning
visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
# 1. Embedding layers
self.time_embeddings = TimeEmbeddings(model_dim, time_dim)
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim)
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim)
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size)
# 2. Rotary positional embeddings
self.text_rope_embeddings = RoPE1D(head_dim)
self.text_transformer_blocks = nn.ModuleList(
[
TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim)
for _ in range(num_text_blocks)
]
)
self.visual_rope_embeddings = RoPE3D(axes_dims)
self.visual_transformer_blocks = nn.ModuleList(
[
TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
for _ in range(num_visual_blocks)
]
)
# 3. Transformer blocks
self.text_transformer_blocks = nn.ModuleList([
TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim)
for _ in range(num_text_blocks)
])
self.visual_transformer_blocks = nn.ModuleList([
TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
for _ in range(num_visual_blocks)
])
# 4. Output layer
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size)
self.gradient_checkpointing = False
def before_text_transformer_blocks(self, text_embed, time, pooled_text_embed, x,
text_rope_pos):
text_embed = self.text_embeddings(text_embed)
time_embed = self.time_embeddings(time)
time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
visual_embed = self.visual_embeddings(x)
text_rope = self.text_rope_embeddings(text_rope_pos)
text_rope = text_rope.unsqueeze(dim=0)
return text_embed, time_embed, text_rope, visual_embed
def before_visual_transformer_blocks(self, visual_embed, visual_rope_pos, scale_factor,
sparse_params):
visual_shape = visual_embed.shape[:-1]
visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape,
block_mask=to_fractal)
return visual_embed, visual_shape, to_fractal, visual_rope
def after_blocks(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed):
visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal)
x = self.out_layer(visual_embed, text_embed, time_embed)
return x
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
pooled_text_embed: torch.Tensor,
timestep: torch.Tensor,
visual_rope_pos: List[torch.Tensor],
text_rope_pos: torch.Tensor,
scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0),
sparse_params: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
Forward pass of the Kandinsky 5.0 3D Transformer.
Args:
hidden_states (`torch.Tensor`):
Input visual latent tensor of shape `(batch_size, num_frames, height, width, channels)`.
encoder_hidden_states (`torch.Tensor`):
Text embeddings from Qwen2.5-VL of shape `(batch_size, sequence_length, text_dim)`.
pooled_text_embed (`torch.Tensor`):
Pooled text embeddings from CLIP of shape `(batch_size, pooled_text_dim)`.
timestep (`torch.Tensor`):
Timestep tensor of shape `(batch_size,)` or `(batch_size * num_frames,)`.
visual_rope_pos (`List[torch.Tensor]`):
List of tensors for visual rotary positional embeddings [temporal, height, width].
text_rope_pos (`torch.Tensor`):
Tensor for text rotary positional embeddings.
scale_factor (`Tuple[float, float, float]`, defaults to `(1.0, 1.0, 1.0)`):
Scale factors for rotary positional embeddings.
sparse_params (`Dict[str, Any]`, *optional*):
Parameters for sparse attention.
return_dict (`bool`, defaults to `True`):
Whether to return a dictionary or a tensor.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned,
otherwise a `tuple` where the first element is the sample tensor.
"""
batch_size, num_frames, height, width, channels = hidden_states.shape
# 1. Process text embeddings
text_embed = self.text_embeddings(encoder_hidden_states)
time_embed = self.time_embeddings(timestep)
# Add pooled text embedding to time embedding
pooled_embed = self.pooled_text_embeddings(pooled_text_embed)
time_embed = time_embed + pooled_embed
# visual_embed shape: [batch_size, seq_len, model_dim]
visual_embed = self.visual_embeddings(hidden_states)
# 3. Text rotary embeddings
text_rope = self.text_rope_embeddings(text_rope_pos)
# 4. Text transformer blocks
i = 0
for text_block in self.text_transformer_blocks:
if self.gradient_checkpointing and self.training:
text_embed = torch.utils.checkpoint.checkpoint(
text_block, text_embed, time_embed, text_rope, use_reentrant=False
)
else:
text_embed = text_block(text_embed, time_embed, text_rope)
i += 1
# 5. Prepare visual rope
visual_shape = visual_embed.shape[:-1]
visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
# visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1])
# visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:]))
visual_embed = visual_embed.flatten(1, 3)
visual_rope = visual_rope.flatten(1, 3)
hidden_states, # x
encoder_hidden_states, #text_embed
timestep, # time
pooled_projections, #pooled_text_embed,
visual_rope_pos,
text_rope_pos,
scale_factor=(1.0, 1.0, 1.0),
sparse_params=None,
return_dict=True,
):
x = hidden_states
text_embed = encoder_hidden_states
time = timestep
pooled_text_embed = pooled_projections
# 6. Visual transformer blocks
i = 0
for visual_block in self.visual_transformer_blocks:
if self.gradient_checkpointing and self.training:
visual_embed = torch.utils.checkpoint.checkpoint(
visual_block,
visual_embed,
text_embed,
time_embed,
visual_rope,
# visual_rope_flat,
sparse_params,
use_reentrant=False,
)
else:
visual_embed = visual_block(
visual_embed, text_embed, time_embed, visual_rope, sparse_params
)
i += 1
text_embed, time_embed, text_rope, visual_embed = self.before_text_transformer_blocks(
text_embed, time, pooled_text_embed, x, text_rope_pos)
# 7. Output projection
visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1)
output = self.out_layer(visual_embed, text_embed, time_embed)
for text_transformer_block in self.text_transformer_blocks:
text_embed = text_transformer_block(text_embed, time_embed, text_rope)
if not return_dict:
return (output,)
visual_embed, visual_shape, to_fractal, visual_rope = self.before_visual_transformer_blocks(
visual_embed, visual_rope_pos, scale_factor, sparse_params)
return Transformer2DModelOutput(sample=output)
for visual_transformer_block in self.visual_transformer_blocks:
visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed,
visual_rope, sparse_params)
x = self.after_blocks(visual_embed, visual_shape, to_fractal, text_embed, time_embed)
if return_dict:
return Transformer2DModelOutput(sample=x)
return x


@@ -300,7 +300,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 768,
num_frames: int = 25,
num_frames: int = 121,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
scheduler_scale: float = 10.0,
@@ -354,6 +354,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
self.transformer.time_embeddings.reset_dtype()
self.transformer.text_rope_embeddings.reset_dtype()
self.transformer.visual_rope_embeddings.reset_dtype()
dtype = self.transformer.dtype
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -394,7 +399,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
width=width,
num_frames=num_frames,
visual_cond=self.transformer.visual_cond,
dtype=self.transformer.dtype,
dtype=dtype,
device=device,
generator=generator,
latents=latents,
@@ -418,41 +423,39 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
timestep = t.unsqueeze(0)
timestep = t.unsqueeze(0).flatten()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
# print(latents.shape)
with torch.autocast(device_type="cuda", dtype=dtype):
pred_velocity = self.transformer(
latents,
text_embeds["text_embeds"],
text_embeds["pooled_embed"],
timestep,
visual_rope_pos,
text_rope_pos,
hidden_states=latents,
encoder_hidden_states=text_embeds["text_embeds"],
pooled_projections=text_embeds["pooled_embed"],
timestep=timestep,
visual_rope_pos=visual_rope_pos,
text_rope_pos=text_rope_pos,
scale_factor=(1, 2, 2),
sparse_params=None,
return_dict=False
)[0]
return_dict=True
).sample
if guidance_scale > 1.0 and negative_text_embeds is not None:
uncond_pred_velocity = self.transformer(
latents,
negative_text_embeds["text_embeds"],
negative_text_embeds["pooled_embed"],
timestep,
visual_rope_pos,
negative_text_rope_pos,
hidden_states=latents,
encoder_hidden_states=negative_text_embeds["text_embeds"],
pooled_projections=negative_text_embeds["pooled_embed"],
timestep=timestep,
visual_rope_pos=visual_rope_pos,
text_rope_pos=negative_text_rope_pos,
scale_factor=(1, 2, 2),
sparse_params=None,
return_dict=False
)[0]
return_dict=True
).sample
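# Classifier-free guidance: extrapolate from the unconditional velocity
# toward the text-conditioned one, v = v_uncond + s * (v_cond - v_uncond).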
pred_velocity = uncond_pred_velocity + guidance_scale * (
pred_velocity - uncond_pred_velocity
)
latents = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0]
latents = torch.cat([latents, visual_cond], dim=-1)
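# Only the first 16 latent channels are denoised in place; the trailing
# channels carry the visual conditioning and are left untouched (previously
# they were re-concatenated on every step).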
latents[:, :, :, :, :16] = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}