
Move Video and Audio Text Encoder Connectors to Transformer (#12)

* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* Precompute run_connectors

* fixes

* Address review comments

* Calculate RoPE double-precision freqs using torch instead of np

* Further simplify LTX 2 RoPE freq calc

* Make connectors a separate module (#18)

* remove text_encoder.py

* address yiyi's comments.

* up

* up

* up

* up

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
dg845
2026-01-05 06:41:13 -08:00
committed by GitHub
parent aae70b90db
commit caae16768a
8 changed files with 629 additions and 815 deletions

View File

@@ -8,18 +8,11 @@ import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer
from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel
from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
CTX = init_empty_weights if is_accelerate_available() else nullcontext
@@ -134,6 +127,17 @@ LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"adaln_single": convert_ltx2_transformer_adaln_single,
}
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
"text_embedding_projection.aggregate_embed": "text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
@@ -146,7 +150,27 @@ LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP = {}
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
connector_prefixes = (
"video_embeddings_connector",
"audio_embeddings_connector",
"transformer_1d_blocks",
"text_embedding_projection.aggregate_embed",
"connectors.",
"video_connector",
"audio_connector",
"text_proj_in",
)
transformer_state_dict, connector_state_dict = {}, {}
for key, value in state_dict.items():
if key.startswith(connector_prefixes):
connector_state_dict[key] = value
else:
transformer_state_dict[key] = value
return transformer_state_dict, connector_state_dict
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
@@ -240,32 +264,109 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
return config, rename_dict, special_keys_remap
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"caption_channels": 16,
"text_proj_in_factor": 3,
"video_connector_num_attention_heads": 4,
"video_connector_attention_head_dim": 8,
"video_connector_num_layers": 1,
"video_connector_num_learnable_registers": None,
"audio_connector_num_attention_heads": 4,
"audio_connector_attention_head_dim": 8,
"audio_connector_num_layers": 1,
"audio_connector_num_learnable_registers": None,
"connector_rope_base_seq_len": 32,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_temporal_positioning": False,
},
}
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"caption_channels": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 30,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"connector_rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
},
}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = {}
return config, rename_dict, special_keys_remap
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
diffusers_config = config["diffusers_config"]
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
with init_empty_weights():
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
for key in list(transformer_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
update_state_dict_inplace(transformer_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for key in list(transformer_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
handler_fn_inplace(key, transformer_state_dict)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
return transformer
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
diffusers_config = config["diffusers_config"]
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
if len(connector_state_dict) == 0:
raise ValueError("No connector weights found in the provided state dict.")
with init_empty_weights():
connectors = LTX2TextConnectors.from_config(diffusers_config)
for key in list(connector_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(connector_state_dict, key, new_key)
for key in list(connector_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, connector_state_dict)
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
return connectors
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
@@ -471,81 +572,6 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D
return vocoder
def get_ltx2_text_encoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"text_encoder_hidden_dim": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 30,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
},
}
rename_dict = LTX_2_0_TEXT_ENCODER_RENAME_DICT
special_keys_remap = LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def get_text_encoder_keys_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str = "model.diffusion_model."):
model_state_dict = {}
model_state_dict["text_proj_in.weight"] = combined_ckpt["text_embedding_projection.aggregate_embed.weight"]
text_encoder_submodules = ["video_embeddings_connector", "audio_embeddings_connector"]
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
new_param_name = param_name.replace(prefix, "")
for submodule_name in text_encoder_submodules:
if new_param_name.startswith(submodule_name):
model_state_dict[new_param_name] = param
break
return model_state_dict
def convert_ltx2_text_encoder(original_state_dict: Dict[str, Any], version: str, text_model_id: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_text_encoder_config(version)
diffusers_config = config["diffusers_config"]
diffusers_config["text_model_id"] = text_model_id
diffusers_config["config_only"] = True
with init_empty_weights():
text_encoder = LTX2AudioVisualTextEncoder.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
base_text_model = AutoModel.from_pretrained(text_model_id)
base_text_model_state_dict= base_text_model.state_dict()
base_text_model_state_dict = {"base_text_encoder." + k: v for k, v in base_text_model_state_dict.items()}
combined_state_dict = {**original_state_dict, **base_text_model_state_dict}
text_encoder.load_state_dict(combined_state_dict, strict=True, assign=True)
return text_encoder
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
if args.original_state_dict_repo_id is not None:
@@ -588,6 +614,13 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
model_state_dict[param_name.replace(prefix, "")] = param
if prefix == "model.diffusion_model.":
# Some checkpoints store the text connector projection outside the diffusion model prefix.
connector_key = "text_embedding_projection.aggregate_embed.weight"
if connector_key in combined_ckpt and connector_key not in model_state_dict:
model_state_dict[connector_key] = combined_ckpt[connector_key]
return model_state_dict
@@ -649,6 +682,7 @@ def get_args():
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
parser.add_argument(
@@ -721,6 +755,15 @@ def main(args):
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
if not args.full_pipeline:
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
if args.connectors or args.full_pipeline:
if args.dit_filename is not None:
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
if not args.full_pipeline:
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
if args.vocoder or args.full_pipeline:
if args.vocoder_filename is not None:
@@ -732,8 +775,8 @@ def main(args):
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
if args.text_encoder or args.full_pipeline:
text_encoder_ckpt = get_text_encoder_keys_from_combined_ckpt(combined_ckpt)
text_encoder = convert_ltx2_text_encoder(text_encoder_ckpt, args.version, args.text_encoder_model_id)
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
if not args.full_pipeline:
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
@@ -758,6 +801,7 @@ def main(args):
audio_vae=audio_vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
connectors=connectors,
transformer=transformer,
vocoder=vocoder,
)
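
A minimal sketch (not part of the conversion script itself) of how the new split-and-rename step routes connector weights out of a combined DiT checkpoint. Key names other than the ones in LTX_2_0_CONNECTORS_KEYS_RENAME_DICT above are placeholders, and tensor shapes are dummies.

import torch

CONNECTOR_PREFIXES = (
    "video_embeddings_connector",
    "audio_embeddings_connector",
    "transformer_1d_blocks",
    "text_embedding_projection.aggregate_embed",
)
RENAME = {
    "video_embeddings_connector": "video_connector",
    "audio_embeddings_connector": "audio_connector",
    "transformer_1d_blocks": "transformer_blocks",
    "text_embedding_projection.aggregate_embed": "text_proj_in",
    "q_norm": "norm_q",
    "k_norm": "norm_k",
}

# Toy combined state dict; "some_dit_block.weight" is a hypothetical non-connector key.
combined = {
    "some_dit_block.weight": torch.zeros(1),
    "video_embeddings_connector.transformer_1d_blocks.0.attn1.q_norm.weight": torch.zeros(1),
    "text_embedding_projection.aggregate_embed.weight": torch.zeros(1),
}

transformer_sd, connector_sd = {}, {}
for key, value in combined.items():
    target = connector_sd if key.startswith(CONNECTOR_PREFIXES) else transformer_sd
    new_key = key
    for old, new in RENAME.items():  # the rename dict only matches connector keys in this toy example
        new_key = new_key.replace(old, new)
    target[new_key] = value

print(sorted(transformer_sd))  # ['some_dit_block.weight']
print(sorted(connector_sd))
# ['text_proj_in.weight', 'video_connector.transformer_blocks.0.attn1.norm_q.weight']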

View File

@@ -14,11 +14,9 @@
# limitations under the License.
import inspect
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
@@ -780,28 +778,12 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
num_rope_elems = num_pos_dims * 2
# 4. Create a 1D grid of frequencies for RoPE
start = 1.0
end = self.theta
if self.double_precision:
pow_indices = np.power(
self.theta,
np.linspace(
np.log(start) / np.log(self.theta),
np.log(end) / np.log(self.theta),
self.dim // num_rope_elems,
dtype=np.float64,
),
)
freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device)
else:
freqs = self.theta ** torch.linspace(
start=math.log(start, self.theta),
end=math.log(end, self.theta),
steps=self.dim // num_rope_elems,
device=device,
dtype=torch.float32,
)
freqs = freqs * math.pi / 2.0
freqs_dtype = torch.float64 if self.double_precision else torch.float32
pow_indices = torch.pow(
self.theta,
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
)
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
# 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape
# (self.dim // num_elems,)
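
The refactor above is a pure reformulation: since log(1)/log(theta) == 0 and log(theta)/log(theta) == 1, raising theta to a linspace over [0, 1] reproduces the old numpy log-space power exactly. A quick illustrative check (constants chosen arbitrarily):

import math
import numpy as np
import torch

theta, num_freqs = 10000.0, 64
start, end = 1.0, theta

# Old numpy-based formulation (double-precision path)
old_freqs = np.power(
    theta,
    np.linspace(np.log(start) / np.log(theta), np.log(end) / np.log(theta), num_freqs, dtype=np.float64),
) * math.pi / 2

# New torch-based formulation
new_freqs = torch.pow(theta, torch.linspace(0.0, 1.0, steps=num_freqs, dtype=torch.float64)) * torch.pi / 2.0

assert torch.allclose(torch.from_numpy(old_freqs), new_freqs)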

View File

@@ -24,7 +24,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"]
_import_structure["connectors"] = ["LTX2TextConnectors"]
_import_structure["vocoder"] = ["LTX2Vocoder"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -37,7 +37,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_ltx2 import LTX2Pipeline
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .text_encoder import LTX2AudioVisualTextEncoder
from .connectors import LTX2TextConnectors
from .vocoder import LTX2Vocoder
else:

View File

@@ -0,0 +1,281 @@
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
class LTX2RotaryPosEmbed1d(nn.Module):
"""
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
"""
def __init__(
self,
dim: int,
base_seq_len: int = 4096,
theta: float = 10000.0,
double_precision: bool = True,
):
super().__init__()
self.dim = dim
self.base_seq_len = base_seq_len
self.theta = theta
self.double_precision = double_precision
def forward(
self,
batch_size: int,
pos: int,
device: Union[str, torch.device],
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Get 1D position ids
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
# Get fractional indices relative to self.base_seq_len
grid_1d = grid_1d / self.base_seq_len
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
# 2. Calculate 1D RoPE frequencies
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
freqs_dtype = torch.float64 if self.double_precision else torch.float32
pow_indices = torch.pow(
self.theta,
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
)
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
# (self.dim // 2,).
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
if self.dim % num_rope_elems != 0:
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
return cos_freqs, sin_freqs
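
A standalone sketch of the frequency math above for the common case where `dim` is even (so the padding branch is skipped); it mirrors `LTX2RotaryPosEmbed1d.forward` rather than importing it.

import torch

def rope_1d(batch_size, seq_len, dim, base_seq_len=4096, theta=10000.0):
    # Fractional positions relative to the base sequence length, shape [B, L]
    grid = torch.arange(seq_len, dtype=torch.float32) / base_seq_len
    grid = grid.unsqueeze(0).repeat(batch_size, 1)
    # dim // 2 frequencies, ranging from pi / 2 up to theta * pi / 2
    pow_indices = torch.pow(theta, torch.linspace(0.0, 1.0, steps=dim // 2))
    freqs = pow_indices * torch.pi / 2.0
    # Fractional positions are remapped from [0, 1] to [-1, 1] before the outer product
    freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs          # [B, L, dim // 2]
    cos = freqs.cos().repeat_interleave(2, dim=-1)        # interleaved -> [B, L, dim]
    sin = freqs.sin().repeat_interleave(2, dim=-1)
    return cos, sin

cos, sin = rope_1d(batch_size=2, seq_len=10, dim=32)
print(cos.shape, sin.shape)  # torch.Size([2, 10, 32]) torch.Size([2, 10, 32])
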
class LTX2TransformerBlock1d(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
activation_fn: str = "gelu-approximate",
eps: float = 1e-6,
):
super().__init__()
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.attn1 = LTX2Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
processor=LTX2AudioVideoAttnProcessor(),
)
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.ff = FeedForward(dim, activation_fn=activation_fn)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
hidden_states = hidden_states + attn_hidden_states
norm_hidden_states = self.norm2(hidden_states)
ff_hidden_states = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_hidden_states
return hidden_states
class LTX2ConnectorTransformer1d(nn.Module):
"""
A 1D sequence transformer for modalities such as text.
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 128,
num_layers: int = 2,
num_learnable_registers: int | None = 128,
rope_base_seq_len: int = 4096,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
eps: float = 1e-6,
causal_temporal_positioning: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.num_learnable_registers = num_learnable_registers
self.learnable_registers = None
if num_learnable_registers is not None:
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
self.learnable_registers = torch.nn.Parameter(init_registers)
self.rope = LTX2RotaryPosEmbed1d(
self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision
)
self.transformer_blocks = torch.nn.ModuleList(
[
LTX2TransformerBlock1d(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_layers)
]
)
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attn_mask_binarize_threshold: float = -9000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
# hidden_states shape: [batch_size, seq_len, hidden_dim]
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
batch_size, seq_len, _ = hidden_states.shape
# 1. Replace padding with learned registers, if using
if self.learnable_registers is not None:
if seq_len % self.num_learnable_registers != 0:
raise ValueError(
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
f" of learnable registers {self.num_learnable_registers}"
)
num_register_repeats = seq_len // self.num_learnable_registers
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
if binary_attn_mask.ndim == 4:
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
padded_hidden_states = [
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
]
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
# Overwrite attention_mask with an all-zeros mask if using registers.
attention_mask = torch.zeros_like(attention_mask)
# 2. Calculate 1D RoPE positional embeddings
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
# 3. Run 1D transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
else:
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
hidden_states = self.norm_out(hidden_states)
return hidden_states, attention_mask
class LTX2TextConnectors(ModelMixin, ConfigMixin):
"""
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and
audio streams.
"""
@register_to_config
def __init__(
self,
caption_channels: int,
text_proj_in_factor: int,
video_connector_num_attention_heads: int,
video_connector_attention_head_dim: int,
video_connector_num_layers: int,
video_connector_num_learnable_registers: int | None,
audio_connector_num_attention_heads: int,
audio_connector_attention_head_dim: int,
audio_connector_num_layers: int,
audio_connector_num_learnable_registers: int | None,
connector_rope_base_seq_len: int,
rope_theta: float,
rope_double_precision: bool,
causal_temporal_positioning: bool,
):
super().__init__()
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
self.video_connector = LTX2ConnectorTransformer1d(
num_attention_heads=video_connector_num_attention_heads,
attention_head_dim=video_connector_attention_head_dim,
num_layers=video_connector_num_layers,
num_learnable_registers=video_connector_num_learnable_registers,
rope_base_seq_len=connector_rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
)
self.audio_connector = LTX2ConnectorTransformer1d(
num_attention_heads=audio_connector_num_attention_heads,
attention_head_dim=audio_connector_attention_head_dim,
num_layers=audio_connector_num_layers,
num_learnable_registers=audio_connector_num_learnable_registers,
rope_base_seq_len=connector_rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
)
def forward(
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
):
# Convert to additive attention mask, if necessary
if not additive_mask:
text_dtype = text_encoder_hidden_states.dtype
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
video_text_embedding = video_text_embedding * attn_mask
new_attn_mask = attn_mask.squeeze(-1)
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
return video_text_embedding, audio_text_embedding, new_attn_mask
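
One detail worth calling out: both pipelines hand the connectors an additive attention mask (large negative values at padded positions), and `LTX2ConnectorTransformer1d.forward` recovers a 0/1 mask from it via `attn_mask_binarize_threshold`. A small sketch of that round trip, using the constants from the code above:

import torch

prompt_attention_mask = torch.tensor([[1, 1, 1, 0, 0]])                 # 0/1 mask from the tokenizer
additive = (1 - prompt_attention_mask.to(torch.float32)) * -1000000.0   # as built at the pipeline call sites
binarized = (additive >= -9000.0).int()                                 # as in LTX2ConnectorTransformer1d.forward
assert (binarized == prompt_attention_mask).all()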

View File

@@ -29,8 +29,8 @@ from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .connectors import LTX2TextConnectors
from .pipeline_output import LTX2PipelineOutput
from .text_encoder import LTX2AudioVisualTextEncoder
from .vocoder import LTX2Vocoder
@@ -192,9 +192,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
tokenizer (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
connectors ([`LTX2TextConnectors`]):
Text connector stack used to adapt text encoder hidden states for the video and audio branches.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder"
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -203,8 +205,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLLTX2Video,
audio_vae: AutoencoderKLLTX2Audio,
text_encoder: LTX2AudioVisualTextEncoder,
text_encoder: Gemma3ForConditionalGeneration,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
connectors: LTX2TextConnectors,
transformer: LTX2VideoTransformer3DModel,
vocoder: LTX2Vocoder,
):
@@ -215,6 +218,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
audio_vae=audio_vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
connectors=connectors,
transformer=transformer,
vocoder=vocoder,
scheduler=scheduler,
@@ -252,6 +256,73 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
)
@staticmethod
def _pack_text_embeds(
text_hidden_states: torch.Tensor,
sequence_lengths: torch.Tensor,
device: Union[str, torch.device],
padding_side: str = "left",
scale_factor: int = 8,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
per-layer in a masked fashion (only over non-padded positions).
Args:
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
The number of valid (non-padded) tokens for each batch instance.
device: (`str` or `torch.device`, *optional*):
torch device to place the resulting embeddings on
padding_side: (`str`, *optional*, defaults to `"left"`):
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
scale_factor (`int`, *optional*, defaults to `8`):
Scaling factor to multiply the normalized hidden states by.
eps (`float`, *optional*, defaults to `1e-6`):
A small positive value for numerical stability when performing normalization.
Returns:
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
Normed and flattened text encoder hidden states.
"""
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
original_dtype = text_hidden_states.dtype
# Create padding mask
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
# Compute the masked mean over non-padding positions; result has shape (batch_size, 1, 1, num_layers)
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
# Compute the min/max over non-padding positions; results have shape (batch_size, 1, 1, num_layers)
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
# Normalization
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
normalized_hidden_states = normalized_hidden_states * scale_factor
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.flatten(2)
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
return normalized_hidden_states
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
@@ -274,7 +345,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
"""
device = device or self._execution_device
dtype = dtype or self.text_encoder.base_text_encoder.dtype
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
@@ -296,29 +367,34 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
text_input_ids = text_input_ids.to(device)
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder(
text_input_ids.to(device),
attention_mask=prompt_attention_mask.to(device),
text_encoder_outputs = self.text_encoder(
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
)
text_encoder_hidden_states = text_encoder_outputs.hidden_states
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
sequence_lengths = prompt_attention_mask.sum(dim=-1)
prompt_embeds = self._pack_text_embeds(
text_encoder_hidden_states,
sequence_lengths,
device=device,
padding_side=self.tokenizer.padding_side,
scale_factor=scale_factor,
)
prompt_embeds = prompt_embeds.to(dtype=dtype)
audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
_, audio_seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
return prompt_embeds, audio_prompt_embeds, prompt_attention_mask
return prompt_embeds, prompt_attention_mask
def encode_prompt(
self,
@@ -327,9 +403,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
audio_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 1024,
@@ -372,7 +446,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
@@ -397,7 +471,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
@@ -406,7 +480,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
dtype=dtype,
)
return prompt_embeds, audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
def check_inputs(
self,
@@ -668,10 +742,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
latents: Optional[torch.Tensor] = None,
audio_latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
audio_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
@@ -732,17 +804,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
audio_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
@@ -812,10 +878,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
# 3. Prepare text embeddings
(
prompt_embeds,
audio_prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_audio_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
@@ -823,9 +887,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
audio_prompt_embeds=audio_prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
negative_audio_prompt_embeds=negative_audio_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
@@ -833,9 +895,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
prompt_embeds, additive_attention_mask, additive_mask=True
)
# 4. Prepare latent variables
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
@@ -939,11 +1005,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
noise_pred_video, noise_pred_audio = self.transformer(
hidden_states=latent_model_input,
audio_hidden_states=audio_latent_model_input,
encoder_hidden_states=prompt_embeds,
audio_encoder_hidden_states=audio_prompt_embeds,
encoder_hidden_states=connector_prompt_embeds,
audio_encoder_hidden_states=connector_audio_prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
audio_encoder_attention_mask=prompt_attention_mask,
encoder_attention_mask=connector_attention_mask,
audio_encoder_attention_mask=connector_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
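
A toy numeric sketch of the masked normalization inside `_pack_text_embeds`: statistics are taken only over valid (non-padded) positions, per batch item and per layer, and the result is rescaled and flattened to `(batch_size, seq_len, hidden_dim * num_layers)`. Shapes and values here are arbitrary.

import torch

batch_size, seq_len, hidden_dim, num_layers = 1, 4, 2, 3
hs = torch.randn(batch_size, seq_len, hidden_dim, num_layers)
sequence_lengths = torch.tensor([3])  # 3 valid tokens; with left padding the first position is padding

token_indices = torch.arange(seq_len).unsqueeze(0)
mask = token_indices >= (seq_len - sequence_lengths[:, None])   # [B, L], left-padding convention
mask = mask[:, :, None, None]

masked = hs.masked_fill(~mask, 0.0)
mean = masked.sum(dim=(1, 2), keepdim=True) / ((sequence_lengths * hidden_dim).view(-1, 1, 1, 1) + 1e-6)
x_min = hs.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = hs.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)

packed = (((hs - mean) / (x_max - x_min + 1e-6)) * 8).flatten(2)   # scale_factor = 8
packed = packed.masked_fill(~mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers), 0.0)
print(packed.shape)  # torch.Size([1, 4, 6])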

View File

@@ -23,14 +23,14 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .connectors import LTX2TextConnectors
from .pipeline_output import LTX2PipelineOutput
from ..pipeline_utils import DiffusionPipeline
from .text_encoder import LTX2AudioVisualTextEncoder
from .vocoder import LTX2Vocoder
from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
from ...models.transformers import LTX2VideoTransformer3DModel
from transformers import GemmaTokenizer, GemmaTokenizerFast
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
from ...video_processor import VideoProcessor
@@ -196,7 +196,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
TODO
"""
model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder"
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -205,8 +205,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLLTX2Video,
audio_vae: AutoencoderKLLTX2Audio,
text_encoder: LTX2AudioVisualTextEncoder,
text_encoder: Gemma3ForConditionalGeneration,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
connectors: LTX2TextConnectors,
transformer: LTX2VideoTransformer3DModel,
vocoder: LTX2Vocoder,
):
@@ -217,6 +218,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
audio_vae=audio_vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
connectors=connectors,
transformer=transformer,
vocoder=vocoder,
scheduler=scheduler,
@@ -254,6 +256,74 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
)
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
def _pack_text_embeds(
text_hidden_states: torch.Tensor,
sequence_lengths: torch.Tensor,
device: Union[str, torch.device],
padding_side: str = "left",
scale_factor: int = 8,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
per-layer in a masked fashion (only over non-padded positions).
Args:
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
The number of valid (non-padded) tokens for each batch instance.
device: (`str` or `torch.device`, *optional*):
torch device to place the resulting embeddings on
padding_side: (`str`, *optional*, defaults to `"left"`):
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
scale_factor (`int`, *optional*, defaults to `8`):
Scaling factor to multiply the normalized hidden states by.
eps (`float`, *optional*, defaults to `1e-6`):
A small positive value for numerical stability when performing normalization.
Returns:
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
Normed and flattened text encoder hidden states.
"""
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
original_dtype = text_hidden_states.dtype
# Create padding mask
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
# Compute the masked mean over non-padding positions; result has shape (batch_size, 1, 1, num_layers)
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
# Compute the min/max over non-padding positions; results have shape (batch_size, 1, 1, num_layers)
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
# Normalization
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
normalized_hidden_states = normalized_hidden_states * scale_factor
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.flatten(2)
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
return normalized_hidden_states
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
def _get_gemma_prompt_embeds(
self,
@@ -277,7 +347,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
"""
device = device or self._execution_device
dtype = dtype or self.text_encoder.base_text_encoder.dtype
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
@@ -299,29 +369,34 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
text_input_ids = text_input_ids.to(device)
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder(
text_input_ids.to(device),
attention_mask=prompt_attention_mask.to(device),
text_encoder_outputs = self.text_encoder(
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
)
text_encoder_hidden_states = text_encoder_outputs.hidden_states
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
sequence_lengths = prompt_attention_mask.sum(dim=-1)
prompt_embeds = self._pack_text_embeds(
text_encoder_hidden_states,
sequence_lengths,
device=device,
padding_side=self.tokenizer.padding_side,
scale_factor=scale_factor,
)
prompt_embeds = prompt_embeds.to(dtype=dtype)
audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
_, audio_seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
return prompt_embeds, audio_prompt_embeds, prompt_attention_mask
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt
def encode_prompt(
@@ -331,9 +406,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
audio_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 1024,
@@ -376,7 +449,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
@@ -401,7 +474,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
@@ -410,7 +483,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
dtype=dtype,
)
return prompt_embeds, audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs
def check_inputs(
@@ -727,10 +800,8 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
latents: Optional[torch.Tensor] = None,
audio_latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
audio_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
@@ -793,17 +864,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
audio_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
@@ -873,10 +938,8 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
# 3. Prepare text embeddings
(
prompt_embeds,
audio_prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_audio_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
@@ -884,9 +947,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
audio_prompt_embeds=audio_prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
negative_audio_prompt_embeds=negative_audio_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
@@ -894,9 +955,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
prompt_embeds, additive_attention_mask, additive_mask=True
)
# 4. Prepare latent variables
if latents is None:
image = self.video_processor.preprocess(image, height=height, width=width)
@@ -1008,12 +1073,12 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
noise_pred_video, noise_pred_audio = self.transformer(
hidden_states=latent_model_input,
audio_hidden_states=audio_latent_model_input,
encoder_hidden_states=prompt_embeds,
audio_encoder_hidden_states=audio_prompt_embeds,
encoder_hidden_states=connector_prompt_embeds,
audio_encoder_hidden_states=connector_audio_prompt_embeds,
timestep=video_timestep,
audio_timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
audio_encoder_attention_mask=prompt_attention_mask,
encoder_attention_mask=connector_attention_mask,
audio_encoder_attention_mask=connector_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,

View File

@@ -1,625 +0,0 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, Gemma3ForConditionalGeneration
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ...models.attention_dispatch import dispatch_attention_fn
from ...models.embeddings import get_1d_rotary_pos_embed
from ...models.modeling_utils import ModelMixin
from ...utils import is_torch_version, logging
from ..pipeline_loading_utils import _fetch_class_library_tuple
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
cos, sin = freqs
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
# Copied from diffusers.models.transformers.transformer_ltx2.LTX2AudioVideoAttnProcessor
class LTX2AudioVideoAttnProcessor:
r"""
Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model.
Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can
support audio-to-video (a2v) and video-to-audio (v2a) cross attention.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
raise ValueError(
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)
def __call__(
self,
attn: "LTX2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
if query_rotary_emb is not None:
query = apply_rotary_emb(query, query_rotary_emb)
key = apply_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
# Copied from diffusers.models.transformers.transformer_ltx2.LTX2Attention
class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
r"""
Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key
RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
"""
_default_processor_cls = LTX2AudioVideoAttnProcessor
_available_processors = [LTX2AudioVideoAttnProcessor]
def __init__(
self,
query_dim: int,
heads: int = 8,
kv_heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = True,
cross_attention_dim: Optional[int] = None,
out_bias: bool = True,
qk_norm: str = "rms_norm_across_heads",
norm_eps: float = 1e-6,
norm_elementwise_affine: bool = True,
processor=None,
):
super().__init__()
if qk_norm != "rms_norm_across_heads":
raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
self.head_dim = dim_head
self.inner_dim = dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = query_dim
self.heads = heads
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
hidden_states = self.processor(
self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs
)
return hidden_states
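# Illustrative sketch (not part of the original file): instantiating `LTX2Attention`
# with small toy dimensions and running self-attention with 1D RoPE tables. The sizes
# and the helper name are assumptions for demonstration, not the shipped configuration.
def _demo_ltx2_attention():
    attn = LTX2Attention(query_dim=64, heads=4, kv_heads=4, dim_head=16)
    hidden_states = torch.randn(1, 10, 64)
    # RoPE tables are applied before the heads are split, so their width must equal
    # inner_dim = heads * dim_head = 64.
    angles = torch.randn(1, 10, 32)
    rotary_emb = (angles.cos().repeat_interleave(2, dim=-1), angles.sin().repeat_interleave(2, dim=-1))
    out = attn(hidden_states, query_rotary_emb=rotary_emb)
    assert out.shape == hidden_states.shape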
class LTX2RotaryPosEmbed1d(nn.Module):
"""
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
"""
def __init__(
self,
dim: int,
base_seq_len: int = 4096,
theta: float = 10000.0,
double_precision: bool = True,
):
super().__init__()
self.dim = dim
self.base_seq_len = base_seq_len
self.theta = theta
self.double_precision = double_precision
def forward(
self,
batch_size: int,
pos: int,
device: Union[str, torch.device],
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Get 1D position ids
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
# Get fractional indices relative to self.base_seq_len
grid_1d = grid_1d / self.base_seq_len
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
# 2. Calculate 1D RoPE frequencies
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
start = 1.0
end = self.theta
if self.double_precision:
pow_indices = np.power(
self.theta,
np.linspace(
np.log(start) / np.log(self.theta),
np.log(end) / np.log(self.theta),
self.dim // num_rope_elems,
dtype=np.float64,
),
)
freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device)
else:
freqs = self.theta ** torch.linspace(
start=math.log(start, self.theta),
end=math.log(end, self.theta),
steps=self.dim // num_rope_elems,
device=device,
dtype=torch.float32,
)
freqs = freqs * math.pi / 2.0
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
# (self.dim // 2,).
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
if self.dim % num_rope_elems != 0:
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
return cos_freqs, sin_freqs
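# Illustrative sketch (not part of the original file): the 1D RoPE module returns
# interleaved (cos, sin) tables of shape [batch_size, seq_len, dim], which are
# consumed by `apply_rotary_emb` above. The toy sizes below are assumptions.
def _demo_rope_1d():
    rope = LTX2RotaryPosEmbed1d(dim=16, base_seq_len=32, theta=10000.0, double_precision=False)
    cos, sin = rope(batch_size=2, pos=8, device="cpu")
    assert cos.shape == (2, 8, 16) and sin.shape == (2, 8, 16)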
class LTX2TransformerBlock1d(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
activation_fn: str = "gelu-approximate",
eps: float = 1e-6,
):
super().__init__()
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.attn1 = LTX2Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
processor=LTX2AudioVideoAttnProcessor(),
)
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.ff = FeedForward(dim, activation_fn=activation_fn)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
hidden_states = hidden_states + attn_hidden_states
norm_hidden_states = self.norm2(hidden_states)
ff_hidden_states = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_hidden_states
return hidden_states
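# Illustrative sketch (not part of the original file): the block above is a standard
# pre-norm residual layout (RMSNorm -> self-attention -> residual, then RMSNorm ->
# feed-forward -> residual), so it preserves the input shape. Toy sizes are assumptions.
def _demo_transformer_block_1d():
    block = LTX2TransformerBlock1d(dim=16, num_attention_heads=2, attention_head_dim=8)
    hidden_states = torch.randn(2, 8, 16)
    out = block(hidden_states)
    assert out.shape == hidden_states.shape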
class LTX2ConnectorTransformer1d(nn.Module):
"""
A 1D sequence transformer for modalities such as text.
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 128,
num_layers: int = 2,
num_learnable_registers: Optional[int] = 128,
rope_base_seq_len: int = 4096,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
eps: float = 1e-6,
causal_temporal_positioning: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.num_learnable_registers = num_learnable_registers
self.learnable_registers = None
if num_learnable_registers is not None:
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
self.learnable_registers = torch.nn.Parameter(init_registers)
self.rope = LTX2RotaryPosEmbed1d(
self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision
)
self.transformer_blocks = torch.nn.ModuleList(
[
LTX2TransformerBlock1d(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_layers)
]
)
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# hidden_states shape: [batch_size, seq_len, hidden_dim]
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
batch_size, seq_len, _ = hidden_states.shape
# 1. Replace padding with learned registers, if using
if self.learnable_registers is not None:
if seq_len % self.num_learnable_registers != 0:
raise ValueError(
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
f" of learnable registers {self.num_learnable_registers}"
)
num_register_repeats = seq_len // self.num_learnable_registers
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
binary_attn_mask = (attention_mask >= -9000.0).int()
if binary_attn_mask.ndim == 4:
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
padded_hidden_states = [
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
]
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
# Overwrite attention_mask with an all-zeros mask if using registers.
attention_mask = torch.zeros_like(attention_mask)
# 2. Calculate 1D RoPE positional embeddings
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
# 3. Run 1D transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
else:
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
hidden_states = self.norm_out(hidden_states)
return hidden_states, attention_mask
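# Illustrative sketch (not part of the original file): running a tiny connector over
# left-padded text embeddings. The additive mask uses 0 for valid tokens and a large
# negative value for padding, matching how `run_connectors` builds it below. All
# sizes and values are assumptions for demonstration.
def _demo_connector_transformer_1d():
    connector = LTX2ConnectorTransformer1d(
        num_attention_heads=2, attention_head_dim=8, num_layers=1, num_learnable_registers=8
    )
    batch_size, seq_len, dim = 2, 16, 16  # dim must equal heads * head_dim
    hidden_states = torch.randn(batch_size, seq_len, dim)
    attention_mask = torch.zeros(batch_size, 1, 1, seq_len)
    attention_mask[:, :, :, :4] = -1e9  # first 4 positions are (left) padding
    out, out_mask = connector(hidden_states, attention_mask)
    # Padded slots are replaced by learned registers, so the output keeps full length.
    assert out.shape == (batch_size, seq_len, dim)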
class LTX2AudioVisualTextEncoder(ModelMixin, ConfigMixin):
ignore_for_config = ["text_model"]
@register_to_config
def __init__(
self,
text_model: Optional[Gemma3ForConditionalGeneration] = None,
text_model_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized",
text_encoder_hidden_dim: Optional[int] = 3840,
text_proj_in_factor: Optional[int] = 49, # Num layers in text encoder + 1
video_connector_num_attention_heads: int = 30,
video_connector_attention_head_dim: int = 128,
video_connector_num_layers: int = 2,
video_connector_num_learnable_registers: int = 128,
audio_connector_num_attention_heads: int = 30,
audio_connector_attention_head_dim: int = 128,
audio_connector_num_layers: int = 2,
audio_connector_num_learnable_registers: Optional[int] = 128,
rope_base_seq_len: int = 4096,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
causal_temporal_positioning: bool = False,
config_only: bool = True,
):
super().__init__()
if text_model is None:
self.set_base_text_encoder(text_model_id, config_only=config_only)
else:
self.base_text_encoder = text_model
if text_encoder_hidden_dim is None:
if hasattr(self.base_text_encoder, "config"):
if hasattr(self.base_text_encoder.config, "hidden_size"):
text_encoder_hidden_dim = getattr(self.base_text_encoder.config, "hidden_size", None)
elif hasattr(self.base_text_encoder.config, "text_config"):
text_encoder_hidden_dim = getattr(self.base_text_encoder.config.text_config, "hidden_size", None)
if text_encoder_hidden_dim is None:
raise ValueError(
"`text_encoder_hidden_dim` is `None` and it cannot be inferred, please provide a value for it."
)
if text_proj_in_factor is None:
num_layers = None
if hasattr(self.base_text_encoder, "config"):
if hasattr(self.base_text_encoder.config, "num_hidden_layers"):
num_layers = getattr(self.base_text_encoder.config, "num_hidden_layers", None)
elif hasattr(self.base_text_encoder.config, "text_config"):
num_layers = getattr(self.base_text_encoder.config.text_config, "num_hidden_layers", None)
if num_layers is None:
raise ValueError(
"`text_proj_in_factor` is `None` and it cannot be inferred, please provide a value for it."
)
text_proj_in_factor = num_layers + 1
self.text_proj_in = nn.Linear(
text_encoder_hidden_dim * text_proj_in_factor, text_encoder_hidden_dim, bias=False
)
self.video_connector = LTX2ConnectorTransformer1d(
num_attention_heads=video_connector_num_attention_heads,
attention_head_dim=video_connector_attention_head_dim,
num_layers=video_connector_num_layers,
num_learnable_registers=video_connector_num_learnable_registers,
rope_base_seq_len=rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
)
self.audio_connector = LTX2ConnectorTransformer1d(
num_attention_heads=audio_connector_num_attention_heads,
attention_head_dim=audio_connector_attention_head_dim,
num_layers=audio_connector_num_layers,
num_learnable_registers=audio_connector_num_learnable_registers,
rope_base_seq_len=rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
)
def set_base_text_encoder(
self, base_text_encoder_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized", config_only: bool = True
):
if config_only:
base_text_encoder_config = AutoConfig.from_pretrained(base_text_encoder_id)
base_text_encoder = AutoModel.from_config(base_text_encoder_config)
else:
base_text_encoder = AutoModel.from_pretrained(base_text_encoder_id)
self.base_text_encoder = base_text_encoder
@staticmethod
def pack_text_embeds(
text_hidden_states: torch.Tensor,
sequence_lengths: torch.Tensor,
device: Union[str, torch.device],
padding_side: str = "left",
scale_factor: int = 8,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
per-layer in a masked fashion (only over non-padded positions).
Args:
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
The number of valid (non-padded) tokens for each batch instance.
device: (`str` or `torch.device`, *optional*):
torch device to place the resulting embeddings on
padding_side: (`str`, *optional*, defaults to `"left"`):
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
scale_factor (`int`, *optional*, defaults to `8`):
Scaling factor to multiply the normalized hidden states by.
eps (`float`, *optional*, defaults to `1e-6`):
A small positive value for numerical stability when performing normalization.
Returns:
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
Normed and flattened text encoder hidden states.
"""
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
original_dtype = text_hidden_states.dtype
# Create padding mask
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
# Compute masked mean over non-padding positions; result has shape (batch_size, 1, 1, num_layers)
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
# Compute min/max over non-padding positions; each has shape (batch_size, 1, 1, num_layers)
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
# Normalization
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
normalized_hidden_states = normalized_hidden_states * scale_factor
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.flatten(2)
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
return normalized_hidden_states
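# Illustrative sketch (not part of the original file): packing per-layer hidden states
# from a toy "text encoder" with left padding. The sizes below are assumptions; the
# real Gemma3 encoder exposes a much larger hidden size and many more layers.
def _demo_pack_text_embeds():
    batch_size, seq_len, hidden_dim, num_layers = 2, 6, 4, 3
    hidden = torch.randn(batch_size, seq_len, hidden_dim, num_layers)
    seq_lens = torch.tensor([6, 4])  # the second sample has 2 left-padded tokens
    packed = LTX2AudioVisualTextEncoder.pack_text_embeds(
        hidden, seq_lens, device="cpu", padding_side="left"
    )
    assert packed.shape == (batch_size, seq_len, hidden_dim * num_layers)
    assert torch.all(packed[1, :2] == 0)  # padded positions are zeroed out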
def run_connectors(
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Run LTX 2.0-specific text embedding post-processing logic on top of the base text encoder hidden_states.
Args:
text_encoder_hidden_states (`torch.Tensor`):
Text encoder packed hidden_states of shape `(batch_size, seq_len, hidden_dim * (num_layers + 1))`.
attention_mask (`torch.Tensor`):
Attention mask of shape `(batch_size, seq_len)`.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`:
Returns a 3-tuple of tensors where the first element is the video text embeddings of shape
`(batch_size, seq_len, hidden_dim)`, the second element is the audio text embeddings of shape
`(batch_size, seq_len, hidden_dim)`, and the third element is an attention mask of shape
`(batch_size, seq_len)`.
"""
# Convert to additive attention mask
text_dtype = text_encoder_hidden_states.dtype
connector_attn_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
connector_attn_mask = connector_attn_mask.to(text_dtype) * torch.finfo(text_dtype).max
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
video_text_embedding, new_attn_mask = self.video_connector(
text_encoder_hidden_states, connector_attn_mask
)
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
video_text_embedding = video_text_embedding * attn_mask
new_attn_mask = attn_mask.squeeze(-1)
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, connector_attn_mask)
return video_text_embedding, audio_text_embedding, new_attn_mask
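# Illustrative sketch (not part of the original file): the core of `run_connectors`
# with toy dimensions and without the (large) base text encoder. It projects packed
# per-layer embeddings down to the connector width, converts the binary attention mask
# into an additive one, and runs the video connector (the audio path is symmetric).
# All names and sizes introduced here are assumptions for demonstration.
def _demo_run_connectors_flow():
    batch_size, seq_len, hidden_dim, num_states = 2, 16, 16, 3
    packed = torch.randn(batch_size, seq_len, hidden_dim * num_states)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
    attention_mask[1, :4] = 0  # left padding on the second sample
    text_proj_in = nn.Linear(hidden_dim * num_states, hidden_dim, bias=False)
    video_connector = LTX2ConnectorTransformer1d(
        num_attention_heads=2, attention_head_dim=8, num_layers=1, num_learnable_registers=8
    )
    # 0 for valid tokens, a very large negative number for padding.
    additive_mask = (attention_mask - 1).reshape(batch_size, 1, -1, seq_len).float()
    additive_mask = additive_mask * torch.finfo(torch.float32).max
    video_text_embedding, _ = video_connector(text_proj_in(packed), additive_mask)
    assert video_text_embedding.shape == (batch_size, seq_len, hidden_dim)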
def forward(
self,
text_input_ids,
attention_mask: Optional[torch.Tensor] = None,
padding_side: str = "left",
scale_factor: int = 8,
):
text_encoder_outputs = self.base_text_encoder(
input_ids=text_input_ids, attention_mask=attention_mask, output_hidden_states=True
)
text_encoder_hidden_states = text_encoder_outputs.hidden_states
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
sequence_lengths = attention_mask.sum(dim=-1)
text_encoder_hidden_states = self.pack_text_embeds(
text_encoder_hidden_states,
sequence_lengths,
device=text_encoder_hidden_states.device,
padding_side=padding_side,
scale_factor=scale_factor,
)
video_text_embedding, audio_text_embedding, new_attn_mask = self.run_connectors(
text_encoder_hidden_states, attention_mask
)
return video_text_embedding, audio_text_embedding, new_attn_mask

View File

@@ -99,6 +99,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
"num_layers": 2,
"qk_norm": "rms_norm_across_heads",
"caption_channels": 16,
"rope_double_precision": False,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict