Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

Commit: remove unused imports
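All hunks below touch the Kandinsky 5 transformer module: the import block loses entries that are no longer referenced (_mask_mod_signature, the BoolTensor/IntTensor/Tensor names, maybe_allow_in_graph, the ContextParallel types, AttentionModuleMixin, dispatch_attention_fn, PixArtAlphaTextProjection, Timesteps), and long statements throughout the file are reflowed into black-style multi-line form without behavior changes.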
@@ -19,21 +19,27 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import BoolTensor, IntTensor, Tensor, nn
-from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
-                                               flex_attention)
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from torch.nn.attention.flex_attention import (
+    BlockMask,
+    flex_attention,
+)
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import (USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers,
-                      unscale_lora_layers)
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
-from ...utils.torch_utils import maybe_allow_in_graph
-from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
-from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
-from ..attention_dispatch import dispatch_attention_fn
+from ..attention import AttentionMixin, FeedForward
 from ..cache_utils import CacheMixin
-from ..embeddings import (PixArtAlphaTextProjection, TimestepEmbedding,
-                          Timesteps, get_1d_rotary_pos_embed)
+from ..embeddings import (
+    TimestepEmbedding,
+    get_1d_rotary_pos_embed,
+)
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
@@ -95,7 +101,7 @@ def local_patching(x, shape, group_size, dim=0):
         g2,
         width // g3,
         g3,
-        *x.shape[dim + 3 :]
+        *x.shape[dim + 3 :],
     )
     x = x.permute(
         *range(len(x.shape[:dim])),
@@ -105,7 +111,7 @@ def local_patching(x, shape, group_size, dim=0):
         dim + 1,
         dim + 3,
         dim + 5,
-        *range(dim + 6, len(x.shape))
+        *range(dim + 6, len(x.shape)),
    )
    x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
    return x
@@ -122,7 +128,7 @@ def local_merge(x, shape, group_size, dim=0):
         g1,
         g2,
         g3,
-        *x.shape[dim + 2 :]
+        *x.shape[dim + 2 :],
     )
     x = x.permute(
         *range(len(x.shape[:dim])),
@@ -132,7 +138,7 @@ def local_merge(x, shape, group_size, dim=0):
         dim + 4,
         dim + 2,
         dim + 5,
-        *range(dim + 6, len(x.shape))
+        *range(dim + 6, len(x.shape)),
    )
    x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
    return x
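For orientation: local_patching regroups a (duration, height, width) token grid into flattened blocks, and local_merge inverts it. Below is a minimal, self-contained sketch of the same view -> permute -> flatten pattern (hypothetical helper name; the block-then-offset permute order is inferred from the dim + 1, dim + 3, dim + 5 tail visible above):

import torch

def patch_3d(x, shape, group_size, dim=0):
    # split each axis into (num_blocks, block_size), matching the x.view(...) above
    d, h, w = shape
    g1, g2, g3 = group_size
    x = x.view(*x.shape[:dim], d // g1, g1, h // g2, g2, w // g3, g3, *x.shape[dim + 3 :])
    # block indices first (dim, dim + 2, dim + 4), then within-block offsets
    x = x.permute(
        *range(dim), dim, dim + 2, dim + 4, dim + 1, dim + 3, dim + 5, *range(dim + 6, x.dim())
    )
    # collapse to (num_blocks, elements_per_block)
    return x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)

x = torch.arange(2 * 4 * 4).view(2, 4, 4)   # (duration, height, width)
patched = patch_3d(x, (2, 4, 4), (1, 2, 2))
print(patched.shape)                        # torch.Size([8, 4]): 8 blocks of 1*2*2 elements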
@@ -172,15 +178,7 @@ def sdpa(q, k, v):
     query = q.transpose(1, 2).contiguous()
     key = k.transpose(1, 2).contiguous()
     value = v.transpose(1, 2).contiguous()
-    out = (
-        F.scaled_dot_product_attention(
-            query,
-            key,
-            value
-        )
-        .transpose(1, 2)
-        .contiguous()
-    )
+    out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous()
     return out
 
 
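The collapsed one-liner is behavior-identical to the removed block. A hedged usage sketch (assumption: q, k, v are laid out as (batch, seq_len, num_heads, head_dim), so the transposes move the head axis to where F.scaled_dot_product_attention expects it):

import torch
import torch.nn.functional as F

def sdpa(q, k, v):
    # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
    query = q.transpose(1, 2).contiguous()
    key = k.transpose(1, 2).contiguous()
    value = v.transpose(1, 2).contiguous()
    # transpose back so callers keep the (batch, seq, heads, head_dim) layout
    return F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous()

q = k = v = torch.randn(1, 16, 8, 64)
print(sdpa(q, k, v).shape)  # torch.Size([1, 16, 8, 64])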
@@ -279,7 +277,7 @@ class Kandinsky5RoPE1D(nn.Module):
         rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
         rope = rope.view(*rope.shape[:-1], 2, 2)
         return rope.unsqueeze(-4)
 
     def reset_dtype(self):
         freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device)
         pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device)
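The surrounding methods build 1D rotary embeddings as per-position 2x2 rotation matrices. A self-contained sketch of that construction (assumption: get_freqs returns the usual inverse-frequency table 1 / max_period**(2i/dim); the helper name here is illustrative):

import torch

def rope_1d(pos, dim, max_period=10000.0):
    # inverse frequencies, one per channel pair (stand-in for get_freqs)
    freq = 1.0 / max_period ** (torch.arange(0, dim, 2).float() / dim)
    args = torch.outer(pos.float(), freq)      # (num_pos, dim // 2), as in reset_dtype
    cosine, sine = torch.cos(args), torch.sin(args)
    rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
    rope = rope.view(*rope.shape[:-1], 2, 2)   # one 2x2 rotation per channel pair
    return rope.unsqueeze(-4)                  # extra axis broadcasts over heads

print(rope_1d(torch.arange(128), dim=64).shape)  # torch.Size([128, 1, 32, 2, 2])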
@@ -307,9 +305,15 @@ class Kandinsky5RoPE3D(nn.Module):
 
         args = torch.cat(
             [
-                args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1),
-                args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1),
-                args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1),
+                args_t.view(1, duration, 1, 1, -1).repeat(
+                    batch_size, 1, height, width, 1
+                ),
+                args_h.view(1, 1, height, 1, -1).repeat(
+                    batch_size, duration, 1, width, 1
+                ),
+                args_w.view(1, 1, 1, width, -1).repeat(
+                    batch_size, duration, height, 1, 1
+                ),
             ],
             dim=-1,
         )
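The cat above tiles each axis's angle table over the full (t, h, w) grid so every cell gets a concatenated angle vector. A small shape-level sketch (assumption: args_t/args_h/args_w are torch.outer(pos, freq) tables whose last dims are half of axes_dims=(16, 24, 24)):

import torch

batch_size, duration, height, width = 1, 4, 3, 3
args_t = torch.randn(duration, 8)   # temporal angles (16 // 2)
args_h = torch.randn(height, 12)    # height angles   (24 // 2)
args_w = torch.randn(width, 12)     # width angles    (24 // 2)

args = torch.cat(
    [
        args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1),
        args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1),
        args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1),
    ],
    dim=-1,
)
print(args.shape)  # torch.Size([1, 4, 3, 3, 32]): one angle vector per grid cell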
@@ -318,12 +322,12 @@ class Kandinsky5RoPE3D(nn.Module):
         rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
         rope = rope.view(*rope.shape[:-1], 2, 2)
         return rope.unsqueeze(-4)
 
     def reset_dtype(self):
         for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)):
             freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device)
             pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device)
-            setattr(self, f'args_{i}', torch.outer(pos, freq))
+            setattr(self, f"args_{i}", torch.outer(pos, freq))
 
 
@@ -341,7 +345,7 @@ class Kandinsky5Modulation(nn.Module):
 
 class Kandinsky5SDPAAttentionProcessor(nn.Module):
     """Custom attention processor for standard SDPA attention"""
 
     def __call__(
         self,
         attn,
@@ -357,7 +361,7 @@ class Kandinsky5SDPAAttentionProcessor(nn.Module):
 
 class Kandinsky5NablaAttentionProcessor(nn.Module):
     """Custom attention processor for Nabla attention"""
 
     def __call__(
         self,
         attn,
@@ -369,11 +373,11 @@ class Kandinsky5NablaAttentionProcessor(nn.Module):
     ):
         if sparse_params is None:
             raise ValueError("sparse_params is required for Nabla attention")
 
         query = query.transpose(1, 2).contiguous()
         key = key.transpose(1, 2).contiguous()
         value = value.transpose(1, 2).contiguous()
 
         block_mask = nablaT_v2(
             query,
             key,
@@ -381,12 +385,7 @@ class Kandinsky5NablaAttentionProcessor(nn.Module):
             thr=sparse_params["P"],
         )
         out = (
-            flex_attention(
-                query,
-                key,
-                value,
-                block_mask=block_mask
-            )
+            flex_attention(query, key, value, block_mask=block_mask)
             .transpose(1, 2)
             .contiguous()
         )
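For context, nablaT_v2 builds a BlockMask that flex_attention then consumes. A hedged sketch with a stand-in causal mask (the real Nabla mask is data-dependent; create_block_mask and flex_attention are the public torch.nn.attention.flex_attention API from PyTorch 2.5 onward, and eager CPU execution may require a recent release):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # stand-in mask_mod; nablaT_v2 would derive block sparsity from q/k instead
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 256, 64
block_mask = create_block_mask(causal, B=B, H=H, Q_LEN=S, KV_LEN=S, device="cpu")
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask).transpose(1, 2).contiguous()
print(out.shape)  # torch.Size([1, 256, 8, 64])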
@@ -407,7 +406,7 @@ class Kandinsky5MultiheadSelfAttentionEnc(nn.Module):
         self.key_norm = nn.RMSNorm(head_dim)
 
         self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
 
         # Initialize attention processor
         self.sdpa_processor = Kandinsky5SDPAAttentionProcessor()
 
@@ -430,13 +429,7 @@ class Kandinsky5MultiheadSelfAttentionEnc(nn.Module):
 
     def scaled_dot_product_attention(self, query, key, value):
         # Use the processor
-        return self.sdpa_processor(
-            attn=self,
-            query=query,
-            key=key,
-            value=value,
-            **{}
-        )
+        return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{})
 
     def out_l(self, x):
         return self.out_layer(x)
@@ -466,7 +459,7 @@ class Kandinsky5MultiheadSelfAttentionDec(nn.Module):
         self.key_norm = nn.RMSNorm(head_dim)
 
         self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
 
         # Initialize attention processors
         self.sdpa_processor = Kandinsky5SDPAAttentionProcessor()
         self.nabla_processor = Kandinsky5NablaAttentionProcessor()
@@ -490,14 +483,8 @@ class Kandinsky5MultiheadSelfAttentionDec(nn.Module):
 
     def attention(self, query, key, value):
         # Use the processor
-        return self.sdpa_processor(
-            attn=self,
-            query=query,
-            key=key,
-            value=value,
-            **{}
-        )
+        return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{})
 
     def nabla(self, query, key, value, sparse_params=None):
         # Use the processor
         return self.nabla_processor(
@@ -506,7 +493,7 @@ class Kandinsky5MultiheadSelfAttentionDec(nn.Module):
             key=key,
             value=value,
             sparse_params=sparse_params,
-            **{}
+            **{},
         )
 
     def out_l(self, x):
@@ -540,7 +527,7 @@ class Kandinsky5MultiheadCrossAttention(nn.Module):
         self.key_norm = nn.RMSNorm(head_dim)
 
         self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
 
         # Initialize attention processor
         self.sdpa_processor = Kandinsky5SDPAAttentionProcessor()
 
@@ -563,13 +550,7 @@ class Kandinsky5MultiheadCrossAttention(nn.Module):
 
     def attention(self, query, key, value):
         # Use the processor
-        return self.sdpa_processor(
-            attn=self,
-            query=query,
-            key=key,
-            value=value,
-            **{}
-        )
+        return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{})
 
     def out_l(self, x):
         return self.out_layer(x)
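All three attention classes route their math through swappable processor objects rather than inlining it. A stripped-down sketch of that indirection (illustrative names; the attn=self handle is what lets a processor read the parent module's norms or projections):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPAProcessor(nn.Module):
    def __call__(self, attn, query, key, value, **kwargs):
        q, k, v = (t.transpose(1, 2).contiguous() for t in (query, key, value))
        return F.scaled_dot_product_attention(q, k, v).transpose(1, 2).contiguous()

class TinyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.sdpa_processor = SDPAProcessor()

    def attention(self, query, key, value):
        # replacing self.sdpa_processor swaps the attention backend
        return self.sdpa_processor(attn=self, query=query, key=key, value=value)

x = torch.randn(1, 16, 8, 64)
print(TinyAttention().attention(x, x, x).shape)  # torch.Size([1, 16, 8, 64])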
@@ -605,7 +586,9 @@ class Kandinsky5OutLayer(nn.Module):
         )
 
     def forward(self, visual_embed, text_embed, time_embed):
-        shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
+        shift, scale = torch.chunk(
+            self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
+        )
         visual_embed = apply_scale_shift_norm(
             self.norm,
             visual_embed,
@@ -646,7 +629,9 @@ class Kandinsky5TransformerEncoderBlock(nn.Module):
         self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
 
     def forward(self, x, time_embed, rope):
-        self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
+        self_attn_params, ff_params = torch.chunk(
+            self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
+        )
         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
         out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
         out = self.self_attention(out, rope)
@@ -678,26 +663,40 @@ class Kandinsky5TransformerDecoderBlock(nn.Module):
             self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
         )
         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
+        visual_out = apply_scale_shift_norm(
+            self.self_attention_norm, visual_embed, scale, shift
+        )
         visual_out = self.self_attention(visual_out, rope, sparse_params)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
 
         shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
+        visual_out = apply_scale_shift_norm(
+            self.cross_attention_norm, visual_embed, scale, shift
+        )
         visual_out = self.cross_attention(visual_out, text_embed)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
 
         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
+        visual_out = apply_scale_shift_norm(
+            self.feed_forward_norm, visual_embed, scale, shift
+        )
         visual_out = self.feed_forward(visual_out)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
         return visual_embed
 
 
-class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
+class Kandinsky5Transformer3DModel(
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    CacheMixin,
+    AttentionMixin,
+):
     """
     A 3D Diffusion Transformer model for video-like data.
     """
 
     _supports_gradient_checkpointing = True
 
     @register_to_config
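The decoder block repeats one motif three times (self-attention, cross-attention, feed-forward): chunk a time-conditioned modulation into shift/scale/gate, normalize-and-modulate, run the sublayer, then apply a gated residual. A compact sketch of the two helpers it leans on (assumption: apply_scale_shift_norm is the usual adaLN form norm(x) * (1 + scale) + shift and apply_gate_sum is a gated residual; the real diffusers helpers may differ in detail):

import torch
import torch.nn as nn

def apply_scale_shift_norm(norm, x, scale, shift):
    return norm(x) * (1.0 + scale) + shift  # assumed adaLN formulation

def apply_gate_sum(x, out, gate):
    return x + gate * out                   # gated residual update

model_dim, time_dim = 64, 32
norm = nn.LayerNorm(model_dim, elementwise_affine=False)
modulation = nn.Linear(time_dim, 3 * model_dim)  # one shift/scale/gate triple

x = torch.randn(2, 10, model_dim)                # (batch, tokens, dim)
time_embed = torch.randn(2, time_dim)
shift, scale, gate = torch.chunk(modulation(time_embed).unsqueeze(dim=1), 3, dim=-1)

out = apply_scale_shift_norm(norm, x, scale, shift)
out = out * 2.0                                  # stand-in for the attention/FF sublayer
x = apply_gate_sum(x, out, gate)
print(x.shape)                                   # torch.Size([2, 10, 64])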
@@ -714,21 +713,21 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
         num_text_blocks=2,
         num_visual_blocks=32,
         axes_dims=(16, 24, 24),
-        visual_cond=False,
+        visual_cond=False,
         attention_type: str = "regular",
-        attention_causal: bool = None,  # Default for Nabla: false
-        attention_local: bool = None,  # Default for Nabla: false
-        attention_glob: bool = None,  # Default for Nabla: false
-        attention_window: int = None,  # Default for Nabla: 3
-        attention_P: float = None,  # Default for Nabla: 0.9
-        attention_wT: int = None,  # Default for Nabla: 11
-        attention_wW: int = None,  # Default for Nabla: 3
-        attention_wH: int = None,  # Default for Nabla: 3
-        attention_add_sta: bool = None,  # Default for Nabla: true
-        attention_method: str = None,  # Default for Nabla: "topcdf"
+        attention_causal: bool = None,  # Default for Nabla: false
+        attention_local: bool = None,  # Default for Nabla: false
+        attention_glob: bool = None,  # Default for Nabla: false
+        attention_window: int = None,  # Default for Nabla: 3
+        attention_P: float = None,  # Default for Nabla: 0.9
+        attention_wT: int = None,  # Default for Nabla: 11
+        attention_wW: int = None,  # Default for Nabla: 3
+        attention_wH: int = None,  # Default for Nabla: 3
+        attention_add_sta: bool = None,  # Default for Nabla: true
+        attention_method: str = None,  # Default for Nabla: "topcdf"
     ):
         super().__init__()
 
         head_dim = sum(axes_dims)
         self.in_visual_dim = in_visual_dim
         self.model_dim = model_dim
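A hedged construction sketch: only the keyword defaults visible in this hunk are taken from the source; the import path assumes diffusers exports the class at top level, and constructor arguments not listed here are assumed to have workable defaults:

from diffusers import Kandinsky5Transformer3DModel  # assumed top-level export

model = Kandinsky5Transformer3DModel(
    num_text_blocks=2,
    num_visual_blocks=32,
    axes_dims=(16, 24, 24),    # head_dim = sum(axes_dims) = 64
    visual_cond=False,
    attention_type="regular",  # the Nabla attention_* knobs stay at their defaults
)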
@@ -737,12 +736,14 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
         self.attention_type = attention_type
 
         visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
 
         # Initialize embeddings
         self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim)
         self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim)
         self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim)
-        self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size)
+        self.visual_embeddings = Kandinsky5VisualEmbeddings(
+            visual_embed_dim, model_dim, patch_size
+        )
 
         # Initialize positional embeddings
         self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim)
@@ -764,10 +765,14 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
         )
 
         # Initialize output layer
-        self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size)
+        self.out_layer = Kandinsky5OutLayer(
+            model_dim, time_dim, out_visual_dim, patch_size
+        )
         self.gradient_checkpointing = False
 
-    def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos):
+    def prepare_text_embeddings(
+        self, text_embed, time, pooled_text_embed, x, text_rope_pos
+    ):
         """Prepare text embeddings and related components"""
         text_embed = self.text_embeddings(text_embed)
         time_embed = self.time_embeddings(time)
@@ -777,38 +782,58 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
         text_rope = text_rope.unsqueeze(dim=0)
         return text_embed, time_embed, text_rope, visual_embed
 
-    def prepare_visual_embeddings(self, visual_embed, visual_rope_pos, scale_factor, sparse_params):
+    def prepare_visual_embeddings(
+        self, visual_embed, visual_rope_pos, scale_factor, sparse_params
+    ):
         """Prepare visual embeddings and related components"""
         visual_shape = visual_embed.shape[:-1]
-        visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
+        visual_rope = self.visual_rope_embeddings(
+            visual_shape, visual_rope_pos, scale_factor
+        )
         to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
-        visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape,
-                                                    block_mask=to_fractal)
+        visual_embed, visual_rope = fractal_flatten(
+            visual_embed, visual_rope, visual_shape, block_mask=to_fractal
+        )
         return visual_embed, visual_shape, to_fractal, visual_rope
 
     def process_text_transformer_blocks(self, text_embed, time_embed, text_rope):
         """Process text through transformer blocks"""
         for text_transformer_block in self.text_transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                text_embed = self._gradient_checkpointing_func(text_transformer_block, text_embed, time_embed, text_rope)
+                text_embed = self._gradient_checkpointing_func(
+                    text_transformer_block, text_embed, time_embed, text_rope
+                )
             else:
                 text_embed = text_transformer_block(text_embed, time_embed, text_rope)
         return text_embed
 
-    def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params):
+    def process_visual_transformer_blocks(
+        self, visual_embed, text_embed, time_embed, visual_rope, sparse_params
+    ):
         """Process visual through transformer blocks"""
         for visual_transformer_block in self.visual_transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                visual_embed = self._gradient_checkpointing_func(visual_transformer_block, visual_embed, text_embed, time_embed,
-                                                                 visual_rope, sparse_params)
+                visual_embed = self._gradient_checkpointing_func(
+                    visual_transformer_block,
+                    visual_embed,
+                    text_embed,
+                    time_embed,
+                    visual_rope,
+                    sparse_params,
+                )
             else:
-                visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed,
-                                                        visual_rope, sparse_params)
+                visual_embed = visual_transformer_block(
+                    visual_embed, text_embed, time_embed, visual_rope, sparse_params
+                )
         return visual_embed
 
-    def prepare_output(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed):
+    def prepare_output(
+        self, visual_embed, visual_shape, to_fractal, text_embed, time_embed
+    ):
         """Prepare the final output"""
-        visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal)
+        visual_embed = fractal_unflatten(
+            visual_embed, visual_shape, block_mask=to_fractal
+        )
         x = self.out_layer(visual_embed, text_embed, time_embed)
         return x
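Both process_*_transformer_blocks methods use the same dispatch: checkpoint a block when gradients are on and checkpointing is enabled, otherwise call it directly. A minimal sketch (assumption: _gradient_checkpointing_func wraps torch.utils.checkpoint.checkpoint, which is how diffusers' ModelMixin sets it up by default):

import torch
from torch.utils.checkpoint import checkpoint

blocks = torch.nn.ModuleList(torch.nn.Linear(32, 32) for _ in range(4))
x = torch.randn(2, 32, requires_grad=True)
gradient_checkpointing = True

for block in blocks:
    if torch.is_grad_enabled() and gradient_checkpointing:
        # store no activations now; recompute them during backward
        x = checkpoint(block, x, use_reentrant=False)
    else:
        x = block(x)
x.sum().backward()
print(x.shape)  # torch.Size([2, 32])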
@@ -846,25 +871,34 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
         text_embed = encoder_hidden_states
         time = timestep
         pooled_text_embed = pooled_projections
 
         # Prepare text embeddings and related components
         text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings(
-            text_embed, time, pooled_text_embed, x, text_rope_pos)
+            text_embed, time, pooled_text_embed, x, text_rope_pos
+        )
 
         # Process text through transformer blocks
-        text_embed = self.process_text_transformer_blocks(text_embed, time_embed, text_rope)
+        text_embed = self.process_text_transformer_blocks(
+            text_embed, time_embed, text_rope
+        )
 
         # Prepare visual embeddings and related components
-        visual_embed, visual_shape, to_fractal, visual_rope = self.prepare_visual_embeddings(
-            visual_embed, visual_rope_pos, scale_factor, sparse_params)
+        visual_embed, visual_shape, to_fractal, visual_rope = (
+            self.prepare_visual_embeddings(
+                visual_embed, visual_rope_pos, scale_factor, sparse_params
+            )
+        )
 
         # Process visual through transformer blocks
         visual_embed = self.process_visual_transformer_blocks(
-            visual_embed, text_embed, time_embed, visual_rope, sparse_params)
+            visual_embed, text_embed, time_embed, visual_rope, sparse_params
+        )
 
         # Prepare final output
-        x = self.prepare_output(visual_embed, visual_shape, to_fractal, text_embed, time_embed)
+        x = self.prepare_output(
+            visual_embed, visual_shape, to_fractal, text_embed, time_embed
+        )
 
         if not return_dict:
             return x