mirror of https://github.com/huggingface/diffusers.git

fix 5sec generation
@@ -13,21 +13,27 @@
# limitations under the License.

import math
from typing import Any, Dict, Optional, Tuple, Union, List
from typing import Any, Dict, List, Optional, Tuple, Union

from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import BoolTensor, IntTensor, Tensor, nn
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
                                               flex_attention)

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers,
                      unscale_lora_layers)
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..embeddings import (PixArtAlphaTextProjection, TimestepEmbedding,
                          Timesteps, get_1d_rotary_pos_embed)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
@@ -35,39 +41,14 @@ from ..normalization import FP32LayerNorm

logger = logging.get_logger(__name__)


if torch.cuda.get_device_capability()[0] >= 9:
    try:
        from flash_attn_interface import flash_attn_func as FA
    except:
        FA = None

    try:
        from flash_attn import flash_attn_func as FA
    except:
        FA = None
else:
    try:
        from flash_attn import flash_attn_func as FA
    except:
        FA = None


def exist(item):
    return item is not None


# @torch.compile()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
    return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)


# @torch.compile()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_gate_sum(x, out, gate):
    return (x + gate * out).to(torch.bfloat16)


# @torch.compile()
@torch.autocast(device_type="cuda", enabled=False)
def apply_rotary(x, rope):
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    x_out = (rope * x_).sum(dim=-1)
    return x_out.reshape(*x.shape).to(torch.bfloat16)


def freeze(model):
    for p in model.parameters():
        p.requires_grad = False
    return model


@torch.autocast(device_type="cuda", enabled=False)
@@ -80,6 +61,116 @@ def get_freqs(dim, max_period=10000.0):
    return freqs


def fractal_flatten(x, rope, shape, block_mask=False):
    if block_mask:
        pixel_size = 8
        x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0)
        rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0)
        x = x.flatten(1, 2)
        rope = rope.flatten(1, 2)
    else:
        x = x.flatten(1, 3)
        rope = rope.flatten(1, 3)
    return x, rope


def fractal_unflatten(x, shape, block_mask=False):
    if block_mask:
        pixel_size = 8
        x = x.reshape(-1, pixel_size**2, *x.shape[1:])
        x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0)
    else:
        x = x.reshape(*shape, *x.shape[2:])
    return x
def local_patching(x, shape, group_size, dim=0):
    duration, height, width = shape
    g1, g2, g3 = group_size
    x = x.reshape(
        *x.shape[:dim],
        duration // g1,
        g1,
        height // g2,
        g2,
        width // g3,
        g3,
        *x.shape[dim + 3 :]
    )
    x = x.permute(
        *range(len(x.shape[:dim])),
        dim,
        dim + 2,
        dim + 4,
        dim + 1,
        dim + 3,
        dim + 5,
        *range(dim + 6, len(x.shape))
    )
    x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
    return x
def local_merge(x, shape, group_size, dim=0):
    duration, height, width = shape
    g1, g2, g3 = group_size
    x = x.reshape(
        *x.shape[:dim],
        duration // g1,
        height // g2,
        width // g3,
        g1,
        g2,
        g3,
        *x.shape[dim + 2 :]
    )
    x = x.permute(
        *range(len(x.shape[:dim])),
        dim,
        dim + 3,
        dim + 1,
        dim + 4,
        dim + 2,
        dim + 5,
        *range(dim + 6, len(x.shape))
    )
    x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
    return x
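# Editor's sketch (not part of the diff): local_merge inverts local_patching
# for matching `shape` and `group_size`. The [T, H, W, C] layout, the 8x8
# pixel grouping, and the helper name below are illustrative assumptions.
def _local_patching_roundtrip_check():
    x = torch.randn(2, 16, 16, 4)                      # [T, H, W, C]
    shape, group = (2, 16, 16), (1, 8, 8)
    patched = local_patching(x, shape, group, dim=0)   # [T/1 * H/8 * W/8, 1*8*8, C]
    assert patched.shape == (8, 64, 4)
    assert torch.equal(local_merge(patched, shape, group, dim=0), x)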
def sdpa(q, k, v):
    query = q.transpose(1, 2).contiguous()
    key = k.transpose(1, 2).contiguous()
    value = v.transpose(1, 2).contiguous()
    out = (
        F.scaled_dot_product_attention(
            query,
            key,
            value
        )
        .transpose(1, 2)
        .contiguous()
    )
    return out
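# Editor's sketch (assumption, not in the diff): `sdpa` takes tensors in the
# flash-attention layout [batch, seq, heads, head_dim], moves heads next to
# batch for F.scaled_dot_product_attention, then transposes back, which is why
# it can stand in for the FA(q=..., k=..., v=...) calls in the attention
# modules below. Helper name and shapes are illustrative.
def _sdpa_shape_check():
    q = torch.randn(1, 32, 8, 64)      # [batch, seq, heads, head_dim]
    out = sdpa(q, q, q)
    assert out.shape == (1, 32, 8, 64)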
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
    return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)


@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_gate_sum(x, out, gate):
    return (x + gate * out).to(torch.bfloat16)


@torch.autocast(device_type="cuda", enabled=False)
def apply_rotary(x, rope):
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    x_out = (rope * x_).sum(dim=-1)
    return x_out.reshape(*x.shape).to(torch.bfloat16)
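# Editor's sketch (assumption, not in the diff): each [2, 2] block produced by
# the RoPE modules acts as a plain 2D rotation on consecutive channel pairs,
# so apply_rotary is equivalent to multiplying every (even, odd) pair by
# [[cos, -sin], [sin, cos]]. Helper name and shapes are illustrative, and the
# check assumes an environment where this module imports.
def _rotary_equivalence_check():
    table = RoPE1D(dim=4)(torch.arange(3))               # [seq, 1, dim/2, 2, 2]
    x = torch.randn(1, 3, 1, 4)                          # [batch, seq, heads, head_dim]
    out = apply_rotary(x, table)
    pairs = x.reshape(1, 3, 1, 2, 2).float()
    manual = torch.einsum("sdij,bshdj->bshdi", table.squeeze(1).float(), pairs)
    assert torch.allclose(out.float(), manual.reshape(1, 3, 1, 4), atol=5e-2)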
class TimeEmbeddings(nn.Module):
    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
@@ -93,12 +184,16 @@ class TimeEmbeddings(nn.Module):
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    @torch.autocast(device_type="cuda", dtype=torch.float32)
    def forward(self, time):
        args = torch.outer(time, self.freqs.to(device=time.device))
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
        return time_embed

    def reset_dtype(self):
        self.freqs = get_freqs(self.model_dim // 2, self.max_period)


class TextEmbeddings(nn.Module):
    def __init__(self, text_dim, model_dim):
@@ -116,7 +211,7 @@ class VisualEmbeddings(nn.Module):
        super().__init__()
        self.patch_size = patch_size
        self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)

    def forward(self, x):
        batch_size, duration, height, width, dim = x.shape
        x = (
@@ -124,7 +219,7 @@ class VisualEmbeddings(nn.Module):
                batch_size,
                duration // self.patch_size[0],
                self.patch_size[0],
                height // self.patch_size[1],
                height // self.patch_size[1],
                self.patch_size[1],
                width // self.patch_size[2],
                self.patch_size[2],
@@ -137,15 +232,6 @@ class VisualEmbeddings(nn.Module):
class RoPE1D(nn.Module):
    """
    1D Rotary Positional Embeddings for text sequences.

    Args:
        dim: Dimension of the rotary embeddings
        max_pos: Maximum sequence length
        max_period: Maximum period for sinusoidal embeddings
    """

    def __init__(self, dim, max_pos=1024, max_period=10000.0):
        super().__init__()
        self.max_period = max_period
@@ -153,22 +239,21 @@ class RoPE1D(nn.Module):
        self.max_pos = max_pos
        freq = get_freqs(dim // 2, max_period)
        pos = torch.arange(max_pos, dtype=freq.dtype)
        self.register_buffer("args", torch.outer(pos, freq), persistent=False)
        self.register_buffer(f"args", torch.outer(pos, freq), persistent=False)

    @torch.autocast(device_type="cuda", enabled=False)
    def forward(self, pos):
        """
        Args:
            pos: Position indices of shape [seq_len] or [batch_size, seq_len]

        Returns:
            Rotary embeddings of shape [seq_len, 1, 2, 2]
        """
        args = self.args[pos]
        cosine = torch.cos(args)
        sine = torch.sin(args)
        rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
        rope = rope.view(*rope.shape[:-1], 2, 2)
        return rope.unsqueeze(-4)

    def reset_dtype(self):
        freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device)
        pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device)
        self.args = torch.outer(pos, freq)
class RoPE3D(nn.Module):
@@ -186,22 +271,29 @@ class RoPE3D(nn.Module):
    @torch.autocast(device_type="cuda", enabled=False)
    def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
        batch_size, duration, height, width = shape

        args_t = self.args_0[pos[0]] / scale_factor[0]
        args_h = self.args_1[pos[1]] / scale_factor[1]
        args_w = self.args_2[pos[2]] / scale_factor[2]

        args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1)
        args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1)
        args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1)

        args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1)

        args = torch.cat(
            [
                args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1),
                args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1),
                args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1),
            ],
            dim=-1,
        )
        cosine = torch.cos(args)
        sine = torch.sin(args)
        rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
        rope = rope.view(*rope.shape[:-1], 2, 2)
        return rope.unsqueeze(-4)

    def reset_dtype(self):
        for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)):
            freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device)
            pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device)
            setattr(self, f'args_{i}', torch.outer(pos, freq))
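# Editor's note (added for clarity, not in the diff): RoPE3D keeps one angle
# table per axis (args_0/args_1/args_2, sized by axes_dims) and concatenates
# them along the channel axis, so the rotary dimension consumed per attention
# head is sum(axes_dims) -- matching how head_dim is derived in
# Kandinsky5Transformer3DModel.__init__ below.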
class Modulation(nn.Module):
@@ -212,10 +304,11 @@ class Modulation(nn.Module):
        self.out_layer.weight.data.zero_()
        self.out_layer.bias.data.zero_()

    @torch.autocast(device_type="cuda", dtype=torch.float32)
    def forward(self, x):
        return self.out_layer(self.activation(x))
class MultiheadSelfAttentionEnc(nn.Module):
    def __init__(self, num_channels, head_dim):
        super().__init__()
@@ -227,9 +320,10 @@ class MultiheadSelfAttentionEnc(nn.Module):
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)
        self.query_norm = nn.RMSNorm(head_dim)
        self.key_norm = nn.RMSNorm(head_dim)

        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)

    def forward(self, x, rope):
    def get_qkv(self, x):
        query = self.to_query(x)
        key = self.to_key(x)
        value = self.to_value(x)
@@ -239,26 +333,31 @@ class MultiheadSelfAttentionEnc(nn.Module):
        key = key.reshape(*shape, self.num_heads, -1)
        value = value.reshape(*shape, self.num_heads, -1)

        query = self.query_norm(query.float()).type_as(query)
        key = self.key_norm(key.float()).type_as(key)
        return query, key, value

    def norm_qk(self, q, k):
        q = self.query_norm(q.float()).type_as(q)
        k = self.key_norm(k.float()).type_as(k)
        return q, k

    def scaled_dot_product_attention(self, query, key, value):
        out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
        return out

    def out_l(self, x):
        return self.out_layer(x)

    def forward(self, x, rope):
        query, key, value = self.get_qkv(x)
        query, key = self.norm_qk(query, key)
        query = apply_rotary(query, rope).type_as(query)
        key = apply_rotary(key, rope).type_as(key)

        # Use torch's scaled_dot_product_attention
        # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionEnc SHAPE")
        # out = F.scaled_dot_product_attention(
        #     query.permute(0, 2, 1, 3),
        #     key.permute(0, 2, 1, 3),
        #     value.permute(0, 2, 1, 3),
        # ).permute(0, 2, 1, 3).flatten(-2, -1)

        out = FA(q=query, k=key, v=value).flatten(-2, -1)
        out = self.scaled_dot_product_attention(query, key, value)

        out = self.out_layer(out)
        out = self.out_l(out)
        return out
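# Editor's note (added for clarity, not in the diff): the attention refactor in
# this commit splits each forward into small get_qkv / norm_qk /
# attention (or scaled_dot_product_attention) / out_l helpers, so the hard
# dependency on flash_attn's FA function can be swapped for the portable
# sdpa() wrapper above without touching the surrounding block logic.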
class MultiheadSelfAttentionDec(nn.Module):
    def __init__(self, num_channels, head_dim):
        super().__init__()
@@ -270,9 +369,10 @@ class MultiheadSelfAttentionDec(nn.Module):
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)
        self.query_norm = nn.RMSNorm(head_dim)
        self.key_norm = nn.RMSNorm(head_dim)

        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)

    def forward(self, x, rope, sparse_params=None):
    def get_qkv(self, x):
        query = self.to_query(x)
        key = self.to_key(x)
        value = self.to_value(x)
@@ -282,24 +382,29 @@ class MultiheadSelfAttentionDec(nn.Module):
        key = key.reshape(*shape, self.num_heads, -1)
        value = value.reshape(*shape, self.num_heads, -1)

        query = self.query_norm(query.float()).type_as(query)
        key = self.key_norm(key.float()).type_as(key)
        return query, key, value

    def norm_qk(self, q, k):
        q = self.query_norm(q.float()).type_as(q)
        k = self.key_norm(k.float()).type_as(k)
        return q, k

    def attention(self, query, key, value):
        out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
        return out

    def out_l(self, x):
        return self.out_layer(x)

    def forward(self, x, rope, sparse_params=None):
        query, key, value = self.get_qkv(x)
        query, key = self.norm_qk(query, key)
        query = apply_rotary(query, rope).type_as(query)
        key = apply_rotary(key, rope).type_as(key)

        # Use standard attention (can be extended with sparse attention)
        # out = F.scaled_dot_product_attention(
        #     query.permute(0, 2, 1, 3),
        #     key.permute(0, 2, 1, 3),
        #     value.permute(0, 2, 1, 3),
        # ).permute(0, 2, 1, 3).flatten(-2, -1)

        # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionDec SHAPE")

        out = FA(q=query, k=key, v=value).flatten(-2, -1)
        out = self.attention(query, key, value)

        out = self.out_layer(out)
        out = self.out_l(out)
        return out
@@ -314,32 +419,39 @@ class MultiheadCrossAttention(nn.Module):
        self.to_value = nn.Linear(num_channels, num_channels, bias=True)
        self.query_norm = nn.RMSNorm(head_dim)
        self.key_norm = nn.RMSNorm(head_dim)

        self.out_layer = nn.Linear(num_channels, num_channels, bias=True)

    def forward(self, x, cond):
    def get_qkv(self, x, cond):
        query = self.to_query(x)
        key = self.to_key(cond)
        value = self.to_value(cond)

        shape, cond_shape = query.shape[:-1], key.shape[:-1]
        query = query.reshape(*shape, self.num_heads, -1)
        key = key.reshape(*cond_shape, self.num_heads, -1)
        value = value.reshape(*cond_shape, self.num_heads, -1)

        query = self.query_norm(query.float()).type_as(query)
        key = self.key_norm(key.float()).type_as(key)

        # out = F.scaled_dot_product_attention(
        #     query.permute(0, 2, 1, 3),
        #     key.permute(0, 2, 1, 3),
        #     value.permute(0, 2, 1, 3),
        # ).permute(0, 2, 1, 3).flatten(-2, -1)

        # print(query.shape, key.shape, value.shape, "QKV MultiheadCrossAttention SHAPE")

        out = FA(q=query, k=key, v=value).flatten(-2, -1)
        return query, key, value

        out = self.out_layer(out)
    def norm_qk(self, q, k):
        q = self.query_norm(q.float()).type_as(q)
        k = self.key_norm(k.float()).type_as(k)
        return q, k

    def attention(self, query, key, value):
        out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
        return out

    def out_l(self, x):
        return self.out_layer(x)

    def forward(self, x, cond):
        query, key, value = self.get_qkv(x, cond)
        query, key = self.norm_qk(query, key)

        out = self.attention(query, key, value)
        out = self.out_l(out)
        return out
@@ -354,6 +466,48 @@ class FeedForward(nn.Module):
        return self.out_layer(self.activation(self.in_layer(x)))


class OutLayer(nn.Module):
    def __init__(self, model_dim, time_dim, visual_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.modulation = Modulation(time_dim, model_dim, 2)
        self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.out_layer = nn.Linear(
            model_dim, math.prod(patch_size) * visual_dim, bias=True
        )

    def forward(self, visual_embed, text_embed, time_embed):
        shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
        visual_embed = apply_scale_shift_norm(
            self.norm,
            visual_embed,
            scale[:, None, None],
            shift[:, None, None],
        ).type_as(visual_embed)
        x = self.out_layer(visual_embed)

        batch_size, duration, height, width, _ = x.shape
        x = (
            x.view(
                batch_size,
                duration,
                height,
                width,
                -1,
                self.patch_size[0],
                self.patch_size[1],
                self.patch_size[2],
            )
            .permute(0, 1, 5, 2, 6, 3, 7, 4)
            .flatten(1, 2)
            .flatten(2, 3)
            .flatten(3, 4)
        )
        return x
class TransformerEncoderBlock(nn.Module):
    def __init__(self, model_dim, time_dim, ff_dim, head_dim):
        super().__init__()
@@ -366,9 +520,7 @@ class TransformerEncoderBlock(nn.Module):
        self.feed_forward = FeedForward(model_dim, ff_dim)

    def forward(self, x, time_embed, rope):
        self_attn_params, ff_params = torch.chunk(
            self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
        )
        self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
        shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
        out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
        out = self.self_attention(out, rope)
@@ -416,246 +568,116 @@ class TransformerDecoderBlock(nn.Module):
        return visual_embed


class OutLayer(nn.Module):
    def __init__(self, model_dim, time_dim, visual_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.modulation = Modulation(time_dim, model_dim, 2)
        self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
        self.out_layer = nn.Linear(
            model_dim, math.prod(patch_size) * visual_dim, bias=True
        )

    def forward(self, visual_embed, text_embed, time_embed):
        # Handle the new batch dimension: [batch, duration, height, width, model_dim]
        batch_size, duration, height, width, _ = visual_embed.shape

        shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)

        # Apply modulation with proper broadcasting for the new shape
        visual_embed = apply_scale_shift_norm(
            self.norm,
            visual_embed,
            scale[:, None, None, None],  # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1]
            shift[:, None, None, None],  # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1]
        ).type_as(visual_embed)

        x = self.out_layer(visual_embed)

        # Now x has shape [batch, duration, height, width, patch_prod * visual_dim]
        x = (
            x.view(
                batch_size,
                duration,
                height,
                width,
                -1,
                self.patch_size[0],
                self.patch_size[1],
                self.patch_size[2],
            )
            .permute(0, 5, 1, 6, 2, 7, 3, 4)  # [batch, patch_t, duration, patch_h, height, patch_w, width, features]
            .flatten(1, 2)  # [batch, patch_t * duration, height, patch_w, width, features]
            .flatten(2, 3)  # [batch, patch_t * duration, patch_h * height, width, features]
            .flatten(3, 4)  # [batch, patch_t * duration, patch_h * height, patch_w * width]
        )
        return x
@maybe_allow_in_graph
class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
    r"""
    A 3D Transformer model for video generation used in Kandinsky 5.0.

    This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods implemented for all models (such as downloading or saving).

    Args:
        in_visual_dim (`int`, defaults to 16):
            Number of channels in the input visual latent.
        out_visual_dim (`int`, defaults to 16):
            Number of channels in the output visual latent.
        time_dim (`int`, defaults to 512):
            Dimension of the time embeddings.
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            Patch size for the visual embeddings (temporal, height, width).
        model_dim (`int`, defaults to 1792):
            Hidden dimension of the transformer model.
        ff_dim (`int`, defaults to 7168):
            Intermediate dimension of the feed-forward networks.
        num_text_blocks (`int`, defaults to 2):
            Number of transformer blocks in the text encoder.
        num_visual_blocks (`int`, defaults to 32):
            Number of transformer blocks in the visual decoder.
        axes_dims (`Tuple[int]`, defaults to `(16, 24, 24)`):
            Dimensions for the rotary positional embeddings (temporal, height, width).
        visual_cond (`bool`, defaults to `True`):
            Whether to use visual conditioning (for image/video conditioning).
        in_text_dim (`int`, defaults to 3584):
            Dimension of the text embeddings from Qwen2.5-VL.
        in_text_dim2 (`int`, defaults to 768):
            Dimension of the pooled text embeddings from CLIP.
    """

    A 3D Diffusion Transformer model for video-like data.
    """
    @register_to_config
    def __init__(
        self,
        in_visual_dim: int = 16,
        out_visual_dim: int = 16,
        time_dim: int = 512,
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        model_dim: int = 1792,
        ff_dim: int = 7168,
        num_text_blocks: int = 2,
        num_visual_blocks: int = 32,
        axes_dims: Tuple[int, int, int] = (16, 24, 24),
        visual_cond: bool = True,
        in_text_dim: int = 3584,
        in_text_dim2: int = 768,
        in_visual_dim=4,
        in_text_dim=3584,
        in_text_dim2=768,
        time_dim=512,
        out_visual_dim=4,
        patch_size=(1, 2, 2),
        model_dim=2048,
        ff_dim=5120,
        num_text_blocks=2,
        num_visual_blocks=32,
        axes_dims=(16, 24, 24),
        visual_cond=False,
    ):
        super().__init__()
        head_dim = sum(axes_dims)
        self.in_visual_dim = in_visual_dim
        self.model_dim = model_dim
        self.patch_size = patch_size
        self.visual_cond = visual_cond

        # Calculate head dimension for attention
        head_dim = sum(axes_dims)

        # Determine visual embedding dimension based on conditioning
        visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim

        # 1. Embedding layers
        self.time_embeddings = TimeEmbeddings(model_dim, time_dim)
        self.text_embeddings = TextEmbeddings(in_text_dim, model_dim)
        self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim)
        self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size)

        # 2. Rotary positional embeddings
        self.text_rope_embeddings = RoPE1D(head_dim)
        self.text_transformer_blocks = nn.ModuleList(
            [
                TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim)
                for _ in range(num_text_blocks)
            ]
        )

        self.visual_rope_embeddings = RoPE3D(axes_dims)
        self.visual_transformer_blocks = nn.ModuleList(
            [
                TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
                for _ in range(num_visual_blocks)
            ]
        )

        # 3. Transformer blocks
        self.text_transformer_blocks = nn.ModuleList([
            TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim)
            for _ in range(num_text_blocks)
        ])

        self.visual_transformer_blocks = nn.ModuleList([
            TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
            for _ in range(num_visual_blocks)
        ])

        # 4. Output layer
        self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size)

        self.gradient_checkpointing = False
    def before_text_transformer_blocks(self, text_embed, time, pooled_text_embed, x,
                                       text_rope_pos):
        text_embed = self.text_embeddings(text_embed)
        time_embed = self.time_embeddings(time)
        time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
        visual_embed = self.visual_embeddings(x)
        text_rope = self.text_rope_embeddings(text_rope_pos)
        text_rope = text_rope.unsqueeze(dim=0)
        return text_embed, time_embed, text_rope, visual_embed

    def before_visual_transformer_blocks(self, visual_embed, visual_rope_pos, scale_factor,
                                         sparse_params):
        visual_shape = visual_embed.shape[:-1]
        visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
        to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
        visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape,
                                                    block_mask=to_fractal)
        return visual_embed, visual_shape, to_fractal, visual_rope

    def after_blocks(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed):
        visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal)
        x = self.out_layer(visual_embed, text_embed, time_embed)
        return x
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        pooled_text_embed: torch.Tensor,
        timestep: torch.Tensor,
        visual_rope_pos: List[torch.Tensor],
        text_rope_pos: torch.Tensor,
        scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0),
        sparse_params: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        Forward pass of the Kandinsky 5.0 3D Transformer.

        Args:
            hidden_states (`torch.Tensor`):
                Input visual latent tensor of shape `(batch_size, num_frames, height, width, channels)`.
            encoder_hidden_states (`torch.Tensor`):
                Text embeddings from Qwen2.5-VL of shape `(batch_size, sequence_length, text_dim)`.
            pooled_text_embed (`torch.Tensor`):
                Pooled text embeddings from CLIP of shape `(batch_size, pooled_text_dim)`.
            timestep (`torch.Tensor`):
                Timestep tensor of shape `(batch_size,)` or `(batch_size * num_frames,)`.
            visual_rope_pos (`List[torch.Tensor]`):
                List of tensors for visual rotary positional embeddings [temporal, height, width].
            text_rope_pos (`torch.Tensor`):
                Tensor for text rotary positional embeddings.
            scale_factor (`Tuple[float, float, float]`, defaults to `(1.0, 1.0, 1.0)`):
                Scale factors for rotary positional embeddings.
            sparse_params (`Dict[str, Any]`, *optional*):
                Parameters for sparse attention.
            return_dict (`bool`, defaults to `True`):
                Whether to return a dictionary or a tensor.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
                If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned,
                otherwise a `tuple` where the first element is the sample tensor.
        """
        batch_size, num_frames, height, width, channels = hidden_states.shape

        # 1. Process text embeddings
        text_embed = self.text_embeddings(encoder_hidden_states)
        time_embed = self.time_embeddings(timestep)

        # Add pooled text embedding to time embedding
        pooled_embed = self.pooled_text_embeddings(pooled_text_embed)
        time_embed = time_embed + pooled_embed

        # visual_embed shape: [batch_size, seq_len, model_dim]
        visual_embed = self.visual_embeddings(hidden_states)

        # 3. Text rotary embeddings
        text_rope = self.text_rope_embeddings(text_rope_pos)

        # 4. Text transformer blocks
        i = 0
        for text_block in self.text_transformer_blocks:
            if self.gradient_checkpointing and self.training:
                text_embed = torch.utils.checkpoint.checkpoint(
                    text_block, text_embed, time_embed, text_rope, use_reentrant=False
                )
            else:
                text_embed = text_block(text_embed, time_embed, text_rope)

            i += 1

        # 5. Prepare visual rope
        visual_shape = visual_embed.shape[:-1]
        visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)

        # visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1])
        # visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:]))
        visual_embed = visual_embed.flatten(1, 3)
        visual_rope = visual_rope.flatten(1, 3)
        hidden_states,  # x
        encoder_hidden_states,  # text_embed
        timestep,  # time
        pooled_projections,  # pooled_text_embed,
        visual_rope_pos,
        text_rope_pos,
        scale_factor=(1.0, 1.0, 1.0),
        sparse_params=None,
        return_dict=True,
    ):
        x = hidden_states
        text_embed = encoder_hidden_states
        time = timestep
        pooled_text_embed = pooled_projections

        # 6. Visual transformer blocks
        i = 0
        for visual_block in self.visual_transformer_blocks:
            if self.gradient_checkpointing and self.training:
                visual_embed = torch.utils.checkpoint.checkpoint(
                    visual_block,
                    visual_embed,
                    text_embed,
                    time_embed,
                    visual_rope,
                    # visual_rope_flat,
                    sparse_params,
                    use_reentrant=False,
                )
            else:
                visual_embed = visual_block(
                    visual_embed, text_embed, time_embed, visual_rope, sparse_params
                )

            i += 1
        text_embed, time_embed, text_rope, visual_embed = self.before_text_transformer_blocks(
            text_embed, time, pooled_text_embed, x, text_rope_pos)

        # 7. Output projection
        visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1)
        output = self.out_layer(visual_embed, text_embed, time_embed)
        for text_transformer_block in self.text_transformer_blocks:
            text_embed = text_transformer_block(text_embed, time_embed, text_rope)

        if not return_dict:
            return (output,)
        visual_embed, visual_shape, to_fractal, visual_rope = self.before_visual_transformer_blocks(
            visual_embed, visual_rope_pos, scale_factor, sparse_params)

        return Transformer2DModelOutput(sample=output)
        for visual_transformer_block in self.visual_transformer_blocks:
            visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed,
                                                    visual_rope, sparse_params)

        x = self.after_blocks(visual_embed, visual_shape, to_fractal, text_embed, time_embed)

        if return_dict:
            return Transformer2DModelOutput(sample=x)

        return x
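# Editor's note (illustrative, not part of the diff): after the refactor the
# transformer is called with keyword arguments and returns a
# Transformer2DModelOutput when return_dict=True, mirroring the pipeline call
# further below. Rope positions are per-axis index tensors prepared by the
# pipeline; the variable names in this sketch are assumptions.
#
#     velocity = transformer(
#         hidden_states=latents,                      # [B, T, H, W, C], channels last
#         encoder_hidden_states=text_embeds,          # [B, seq_len, in_text_dim]
#         pooled_projections=pooled_embed,            # [B, in_text_dim2]
#         timestep=t.flatten(),                       # [B]
#         visual_rope_pos=visual_rope_pos,            # [temporal, height, width] indices
#         text_rope_pos=text_rope_pos,                # [seq_len] indices
#         scale_factor=(1, 2, 2),
#         return_dict=True,
#     ).sample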
@@ -300,7 +300,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 768,
        num_frames: int = 25,
        num_frames: int = 121,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        scheduler_scale: float = 10.0,
@@ -354,6 +354,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                the first element is a list with the generated images and the second element is a list of `bool`s
                indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
        """
        self.transformer.time_embeddings.reset_dtype()
        self.transformer.text_rope_embeddings.reset_dtype()
        self.transformer.visual_rope_embeddings.reset_dtype()

        dtype = self.transformer.dtype

        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
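        # Editor's note (not in the diff): the reset_dtype() calls above rebuild
        # the float32 frequency/rope tables via get_freqs after the transformer
        # may have been cast to a lower-precision dtype, so positional tables
        # are recomputed in full precision before sampling starts.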
@@ -394,7 +399,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
            width=width,
            num_frames=num_frames,
            visual_cond=self.transformer.visual_cond,
            dtype=self.transformer.dtype,
            dtype=dtype,
            device=device,
            generator=generator,
            latents=latents,
@@ -418,41 +423,39 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                timestep = t.unsqueeze(0)
                timestep = t.unsqueeze(0).flatten()

                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    # print(latents.shape)
                with torch.autocast(device_type="cuda", dtype=dtype):
                    pred_velocity = self.transformer(
                        latents,
                        text_embeds["text_embeds"],
                        text_embeds["pooled_embed"],
                        timestep,
                        visual_rope_pos,
                        text_rope_pos,
                        hidden_states=latents,
                        encoder_hidden_states=text_embeds["text_embeds"],
                        pooled_projections=text_embeds["pooled_embed"],
                        timestep=timestep,
                        visual_rope_pos=visual_rope_pos,
                        text_rope_pos=text_rope_pos,
                        scale_factor=(1, 2, 2),
                        sparse_params=None,
                        return_dict=False
                    )[0]

                        return_dict=True
                    ).sample

                    if guidance_scale > 1.0 and negative_text_embeds is not None:
                        uncond_pred_velocity = self.transformer(
                            latents,
                            negative_text_embeds["text_embeds"],
                            negative_text_embeds["pooled_embed"],
                            timestep,
                            visual_rope_pos,
                            negative_text_rope_pos,
                            hidden_states=latents,
                            encoder_hidden_states=negative_text_embeds["text_embeds"],
                            pooled_projections=negative_text_embeds["pooled_embed"],
                            timestep=timestep,
                            visual_rope_pos=visual_rope_pos,
                            text_rope_pos=negative_text_rope_pos,
                            scale_factor=(1, 2, 2),
                            sparse_params=None,
                            return_dict=False
                        )[0]
                            return_dict=True
                        ).sample

                        pred_velocity = uncond_pred_velocity + guidance_scale * (
                            pred_velocity - uncond_pred_velocity
                        )

                latents = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0]
                latents = torch.cat([latents, visual_cond], dim=-1)
                latents[:, :, :, :, :16] = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0]
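                # Editor's note (not in the diff): only the first 16 latent
                # channels are denoised here; with visual conditioning enabled
                # the remaining channels appear to carry the conditioning
                # tensor, so the in-place slice update keeps them fixed instead
                # of re-concatenating visual_cond every step as the older
                # two-line variant above did.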
                if callback_on_step_end is not None:
                    callback_kwargs = {}