From caae16768a240def1a366d8173d1c4e825bfc5c8 Mon Sep 17 00:00:00 2001
From: dg845 <58458699+dg845@users.noreply.github.com>
Date: Mon, 5 Jan 2026 06:41:13 -0800
Subject: [PATCH] Move Video and Audio Text Encoder Connectors to Transformer (#12)

* Denormalize audio latents in I2V pipeline (analogous to T2V change)
* Initial refactor to put video and audio text encoder connectors in transformer
* Get LTX 2 transformer tests working after connector refactor
* precompute run_connectors
* fixes
* Address review comments
* Calculate RoPE double precision freqs using torch instead of np
* Further simplify LTX 2 RoPE freq calc
* Make connectors a separate module (#18)
* remove text_encoder.py
* address yiyi's comments.
* up
* up
* up
* up

---------

Co-authored-by: sayakpaul
---
 scripts/convert_ltx2_to_diffusers.py          | 230 ++++---
 .../models/transformers/transformer_ltx2.py   |  30 +-
 src/diffusers/pipelines/ltx2/__init__.py      |   4 +-
 src/diffusers/pipelines/ltx2/connectors.py    | 281 ++++++++
 src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 136 +++-
 .../ltx2/pipeline_ltx2_image2video.py         | 137 +++-
 src/diffusers/pipelines/ltx2/text_encoder.py  | 625 ------------------
 .../test_models_transformer_ltx2.py           |   1 +
 8 files changed, 629 insertions(+), 815 deletions(-)
 create mode 100644 src/diffusers/pipelines/ltx2/connectors.py
 delete mode 100644 src/diffusers/pipelines/ltx2/text_encoder.py

diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py
index eb311c3bc0..9f58d8f344 100644
--- a/scripts/convert_ltx2_to_diffusers.py
+++ b/scripts/convert_ltx2_to_diffusers.py
@@ -8,18 +8,11 @@
 import safetensors.torch
 import torch
 from accelerate import init_empty_weights
 from huggingface_hub import hf_hub_download
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration
 
-from diffusers import (
-    AutoencoderKLLTX2Audio,
-    AutoencoderKLLTX2Video,
-    FlowMatchEulerDiscreteScheduler,
-    LTX2Pipeline,
-    LTX2VideoTransformer3DModel,
-)
+from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel
+from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
 from diffusers.utils.import_utils import is_accelerate_available
-from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder
-from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
 
 CTX = init_empty_weights if is_accelerate_available() else nullcontext
 
@@ -134,6 +127,17 @@ LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
     "adaln_single": convert_ltx2_transformer_adaln_single,
 }
 
+LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
+    "connectors.": "",
+    "video_embeddings_connector": "video_connector",
+    "audio_embeddings_connector": "audio_connector",
+    "transformer_1d_blocks": "transformer_blocks",
+    "text_embedding_projection.aggregate_embed": "text_proj_in",
+    # Attention QK Norms
+    "q_norm": "norm_q",
+    "k_norm": "norm_k",
+}
+
 LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
     "per_channel_statistics.channel": remove_keys_inplace,
     "per_channel_statistics.mean-of-stds": remove_keys_inplace,
@@ -146,7 +150,27 @@ LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
 
 LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
 
-LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP = {}
+
+def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    connector_prefixes = (
+        "video_embeddings_connector",
+        "audio_embeddings_connector",
+        "transformer_1d_blocks",
+
"text_embedding_projection.aggregate_embed", + "connectors.", + "video_connector", + "audio_connector", + "text_proj_in", + ) + + transformer_state_dict, connector_state_dict = {}, {} + for key, value in state_dict.items(): + if key.startswith(connector_prefixes): + connector_state_dict[key] = value + else: + transformer_state_dict[key] = value + + return transformer_state_dict, connector_state_dict def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: @@ -240,32 +264,109 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, return config, rename_dict, special_keys_remap +def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + if version == "test": + config = { + "model_id": "diffusers-internal-dev/dummy-ltx2", + "diffusers_config": { + "caption_channels": 16, + "text_proj_in_factor": 3, + "video_connector_num_attention_heads": 4, + "video_connector_attention_head_dim": 8, + "video_connector_num_layers": 1, + "video_connector_num_learnable_registers": None, + "audio_connector_num_attention_heads": 4, + "audio_connector_attention_head_dim": 8, + "audio_connector_num_layers": 1, + "audio_connector_num_learnable_registers": None, + "connector_rope_base_seq_len": 32, + "rope_theta": 10000.0, + "rope_double_precision": False, + "causal_temporal_positioning": False, + }, + } + elif version == "2.0": + config = { + "model_id": "diffusers-internal-dev/new-ltx-model", + "diffusers_config": { + "caption_channels": 3840, + "text_proj_in_factor": 49, + "video_connector_num_attention_heads": 30, + "video_connector_attention_head_dim": 128, + "video_connector_num_layers": 2, + "video_connector_num_learnable_registers": 128, + "audio_connector_num_attention_heads": 30, + "audio_connector_attention_head_dim": 128, + "audio_connector_num_layers": 2, + "audio_connector_num_learnable_registers": 128, + "connector_rope_base_seq_len": 4096, + "rope_theta": 10000.0, + "rope_double_precision": True, + "causal_temporal_positioning": False, + }, + } + + rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT + special_keys_remap = {} + + return config, rename_dict, special_keys_remap + + def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]: config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version) diffusers_config = config["diffusers_config"] + transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict) + with init_empty_weights(): transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config) # Handle official code --> diffusers key remapping via the remap dict - for key in list(original_state_dict.keys()): + for key in list(transformer_state_dict.keys()): new_key = key[:] for replace_key, rename_key in rename_dict.items(): new_key = new_key.replace(replace_key, rename_key) - update_state_dict_inplace(original_state_dict, key, new_key) + update_state_dict_inplace(transformer_state_dict, key, new_key) # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in # special_keys_remap - for key in list(original_state_dict.keys()): + for key in list(transformer_state_dict.keys()): for special_key, handler_fn_inplace in special_keys_remap.items(): if special_key not in key: continue - handler_fn_inplace(key, original_state_dict) + handler_fn_inplace(key, transformer_state_dict) - transformer.load_state_dict(original_state_dict, strict=True, assign=True) + 
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True) return transformer +def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors: + config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version) + diffusers_config = config["diffusers_config"] + + _, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict) + if len(connector_state_dict) == 0: + raise ValueError("No connector weights found in the provided state dict.") + + with init_empty_weights(): + connectors = LTX2TextConnectors.from_config(diffusers_config) + + for key in list(connector_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in rename_dict.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(connector_state_dict, key, new_key) + + for key in list(connector_state_dict.keys()): + for special_key, handler_fn_inplace in special_keys_remap.items(): + if special_key not in key: + continue + handler_fn_inplace(key, connector_state_dict) + + connectors.load_state_dict(connector_state_dict, strict=True, assign=True) + return connectors + + def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: if version == "test": config = { @@ -471,81 +572,6 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D return vocoder -def get_ltx2_text_encoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: - if version == "2.0": - config = { - "model_id": "diffusers-internal-dev/new-ltx-model", - "diffusers_config": { - "text_encoder_hidden_dim": 3840, - "text_proj_in_factor": 49, - "video_connector_num_attention_heads": 30, - "video_connector_attention_head_dim": 128, - "video_connector_num_layers": 2, - "video_connector_num_learnable_registers": 128, - "audio_connector_num_attention_heads": 30, - "audio_connector_attention_head_dim": 128, - "audio_connector_num_layers": 2, - "audio_connector_num_learnable_registers": 128, - "rope_base_seq_len": 4096, - "rope_theta": 10000.0, - "rope_double_precision": True, - "causal_temporal_positioning": False, - }, - } - rename_dict = LTX_2_0_TEXT_ENCODER_RENAME_DICT - special_keys_remap = LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP - return config, rename_dict, special_keys_remap - - -def get_text_encoder_keys_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str = "model.diffusion_model."): - model_state_dict = {} - - model_state_dict["text_proj_in.weight"] = combined_ckpt["text_embedding_projection.aggregate_embed.weight"] - - text_encoder_submodules = ["video_embeddings_connector", "audio_embeddings_connector"] - for param_name, param in combined_ckpt.items(): - if param_name.startswith(prefix): - new_param_name = param_name.replace(prefix, "") - for submodule_name in text_encoder_submodules: - if new_param_name.startswith(submodule_name): - model_state_dict[new_param_name] = param - break - - return model_state_dict - - -def convert_ltx2_text_encoder(original_state_dict: Dict[str, Any], version: str, text_model_id: str) -> Dict[str, Any]: - config, rename_dict, special_keys_remap = get_ltx2_text_encoder_config(version) - diffusers_config = config["diffusers_config"] - diffusers_config["text_model_id"] = text_model_id - diffusers_config["config_only"] = True - - with init_empty_weights(): - text_encoder = LTX2AudioVisualTextEncoder.from_config(diffusers_config) - - # Handle official code --> diffusers key remapping via the remap dict - 
for key in list(original_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in rename_dict.items(): - new_key = new_key.replace(replace_key, rename_key) - update_state_dict_inplace(original_state_dict, key, new_key) - - # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in - # special_keys_remap - for key in list(original_state_dict.keys()): - for special_key, handler_fn_inplace in special_keys_remap.items(): - if special_key not in key: - continue - handler_fn_inplace(key, original_state_dict) - - base_text_model = AutoModel.from_pretrained(text_model_id) - base_text_model_state_dict= base_text_model.state_dict() - base_text_model_state_dict = {"base_text_encoder." + k: v for k, v in base_text_model_state_dict.items()} - combined_state_dict = {**original_state_dict, **base_text_model_state_dict} - - text_encoder.load_state_dict(combined_state_dict, strict=True, assign=True) - return text_encoder - def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]: if args.original_state_dict_repo_id is not None: @@ -588,6 +614,13 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi for param_name, param in combined_ckpt.items(): if param_name.startswith(prefix): model_state_dict[param_name.replace(prefix, "")] = param + + if prefix == "model.diffusion_model.": + # Some checkpoints store the text connector projection outside the diffusion model prefix. + connector_key = "text_embedding_projection.aggregate_embed.weight" + if connector_key in combined_ckpt and connector_key not in model_state_dict: + model_state_dict[connector_key] = combined_ckpt[connector_key] + return model_state_dict @@ -649,6 +682,7 @@ def get_args(): parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model") parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model") parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model") + parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model") parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model") parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder") parser.add_argument( @@ -721,6 +755,15 @@ def main(args): transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version) if not args.full_pipeline: transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer")) + + if args.connectors or args.full_pipeline: + if args.dit_filename is not None: + original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename) + elif combined_ckpt is not None: + original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix) + connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version) + if not args.full_pipeline: + connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors")) if args.vocoder or args.full_pipeline: if args.vocoder_filename is not None: @@ -732,8 +775,8 @@ def main(args): vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder")) if args.text_encoder or args.full_pipeline: - text_encoder_ckpt = get_text_encoder_keys_from_combined_ckpt(combined_ckpt) - text_encoder = convert_ltx2_text_encoder(text_encoder_ckpt, args.version, args.text_encoder_model_id) + # 
text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id) + text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id) if not args.full_pipeline: text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder")) @@ -758,6 +801,7 @@ def main(args): audio_vae=audio_vae, text_encoder=text_encoder, tokenizer=tokenizer, + connectors=connectors, transformer=transformer, vocoder=vocoder, ) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 1f685fdc3a..d0e5da2390 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -14,11 +14,9 @@ # limitations under the License. import inspect -import math from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn @@ -780,28 +778,12 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module): num_rope_elems = num_pos_dims * 2 # 4. Create a 1D grid of frequencies for RoPE - start = 1.0 - end = self.theta - if self.double_precision: - pow_indices = np.power( - self.theta, - np.linspace( - np.log(start) / np.log(self.theta), - np.log(end) / np.log(self.theta), - self.dim // num_rope_elems, - dtype=np.float64, - ), - ) - freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device) - else: - freqs = self.theta ** torch.linspace( - start=math.log(start, self.theta), - end=math.log(end, self.theta), - steps=self.dim // num_rope_elems, - device=device, - dtype=torch.float32, - ) - freqs = freqs * math.pi / 2.0 + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) # 5. 
Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape # (self.dim // num_elems,) diff --git a/src/diffusers/pipelines/ltx2/__init__.py b/src/diffusers/pipelines/ltx2/__init__.py index a97c836e0c..95d5f8d4a4 100644 --- a/src/diffusers/pipelines/ltx2/__init__.py +++ b/src/diffusers/pipelines/ltx2/__init__.py @@ -24,7 +24,7 @@ except OptionalDependencyNotAvailable: else: _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"] _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"] - _import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"] + _import_structure["connectors"] = ["LTX2TextConnectors"] _import_structure["vocoder"] = ["LTX2Vocoder"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -37,7 +37,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: else: from .pipeline_ltx2 import LTX2Pipeline from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline - from .text_encoder import LTX2AudioVisualTextEncoder + from .connectors import LTX2TextConnectors from .vocoder import LTX2Vocoder else: diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py new file mode 100644 index 0000000000..ce4dc4494f --- /dev/null +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -0,0 +1,281 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.attention import FeedForward +from ...models.modeling_utils import ModelMixin +from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor + + +class LTX2RotaryPosEmbed1d(nn.Module): + """ + 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. + """ + + def __init__( + self, + dim: int, + base_seq_len: int = 4096, + theta: float = 10000.0, + double_precision: bool = True, + ): + super().__init__() + self.dim = dim + self.base_seq_len = base_seq_len + self.theta = theta + self.double_precision = double_precision + + def forward( + self, + batch_size: int, + pos: int, + device: Union[str, torch.device], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Get 1D position ids + grid_1d = torch.arange(pos, dtype=torch.float32, device=device) + # Get fractional indices relative to self.base_seq_len + grid_1d = grid_1d / self.base_seq_len + grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] + + # 2. Calculate 1D RoPE frequencies + num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 + freqs_dtype = torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape + # (self.dim // 2,). + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] + + # 4. 
Get real, interleaved (cos, sin) frequencies, padded to self.dim + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + return cos_freqs, sin_freqs + + +class LTX2TransformerBlock1d(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "gelu-approximate", + eps: float = 1e-6, + ): + super().__init__() + + self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + processor=LTX2AudioVideoAttnProcessor(), + ) + + self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) + hidden_states = hidden_states + attn_hidden_states + + norm_hidden_states = self.norm2(hidden_states) + ff_hidden_states = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_hidden_states + + return hidden_states + + +class LTX2ConnectorTransformer1d(nn.Module): + """ + A 1D sequence transformer for modalities such as text. + + In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. 
+ """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 128, + num_layers: int = 2, + num_learnable_registers: int | None = 128, + rope_base_seq_len: int = 4096, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + eps: float = 1e-6, + causal_temporal_positioning: bool = False, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + + self.num_learnable_registers = num_learnable_registers + self.learnable_registers = None + if num_learnable_registers is not None: + init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 + self.learnable_registers = torch.nn.Parameter(init_registers) + + self.rope = LTX2RotaryPosEmbed1d( + self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision + ) + + self.transformer_blocks = torch.nn.ModuleList( + [ + LTX2TransformerBlock1d( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attn_mask_binarize_threshold: float = -9000.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # hidden_states shape: [batch_size, seq_len, hidden_dim] + # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] + batch_size, seq_len, _ = hidden_states.shape + + # 1. Replace padding with learned registers, if using + if self.learnable_registers is not None: + if seq_len % self.num_learnable_registers != 0: + raise ValueError( + f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" + f" of learnable registers {self.num_learnable_registers}" + ) + + num_register_repeats = seq_len // self.num_learnable_registers + registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim] + + binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int() + if binary_attn_mask.ndim == 4: + binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] + + hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)] + valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] + pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] + padded_hidden_states = [ + F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths) + ] + padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D] + + flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1] + hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers + + # Overwrite attention_mask with an all-zeros mask if using registers. + attention_mask = torch.zeros_like(attention_mask) + + # 2. Calculate 1D RoPE positional embeddings + rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) + + # 3. 
Run 1D transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) + else: + hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) + + hidden_states = self.norm_out(hidden_states) + + return hidden_states, attention_mask + + +class LTX2TextConnectors(ModelMixin, ConfigMixin): + """ + Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and + audio streams. + """ + + @register_to_config + def __init__( + self, + caption_channels: int, + text_proj_in_factor: int, + video_connector_num_attention_heads: int, + video_connector_attention_head_dim: int, + video_connector_num_layers: int, + video_connector_num_learnable_registers: int | None, + audio_connector_num_attention_heads: int, + audio_connector_attention_head_dim: int, + audio_connector_num_layers: int, + audio_connector_num_learnable_registers: int | None, + connector_rope_base_seq_len: int, + rope_theta: float, + rope_double_precision: bool, + causal_temporal_positioning: bool, + ): + super().__init__() + self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False) + self.video_connector = LTX2ConnectorTransformer1d( + num_attention_heads=video_connector_num_attention_heads, + attention_head_dim=video_connector_attention_head_dim, + num_layers=video_connector_num_layers, + num_learnable_registers=video_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + ) + self.audio_connector = LTX2ConnectorTransformer1d( + num_attention_heads=audio_connector_num_attention_heads, + attention_head_dim=audio_connector_attention_head_dim, + num_layers=audio_connector_num_layers, + num_learnable_registers=audio_connector_num_learnable_registers, + rope_base_seq_len=connector_rope_base_seq_len, + rope_theta=rope_theta, + rope_double_precision=rope_double_precision, + causal_temporal_positioning=causal_temporal_positioning, + ) + + def forward( + self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False + ): + # Convert to additive attention mask, if necessary + if not additive_mask: + text_dtype = text_encoder_hidden_states.dtype + attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max + + text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) + + video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask) + + attn_mask = (new_attn_mask < 1e-6).to(torch.int64) + attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) + video_text_embedding = video_text_embedding * attn_mask + new_attn_mask = attn_mask.squeeze(-1) + + audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask) + + return video_text_embedding, audio_text_embedding, new_attn_mask diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py index 2617e5cacb..08fad91c41 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py @@ -29,8 +29,8 @@ 
from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline +from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput -from .text_encoder import LTX2AudioVisualTextEncoder from .vocoder import LTX2Vocoder @@ -192,9 +192,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix tokenizer (`T5TokenizerFast`): Second Tokenizer of class [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + connectors ([`LTX2TextConnectors`]): + Text connector stack used to adapt text encoder hidden states for the video and audio branches. """ - model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder" + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -203,8 +205,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKLLTX2Video, audio_vae: AutoencoderKLLTX2Audio, - text_encoder: LTX2AudioVisualTextEncoder, + text_encoder: Gemma3ForConditionalGeneration, tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, vocoder: LTX2Vocoder, ): @@ -215,6 +218,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix audio_vae=audio_vae, text_encoder=text_encoder, tokenizer=tokenizer, + connectors=connectors, transformer=transformer, vocoder=vocoder, scheduler=scheduler, @@ -252,6 +256,73 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) + @staticmethod + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: Union[str, torch.device], + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
+ """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + def _get_gemma_prompt_embeds( self, prompt: Union[str, List[str]], @@ -274,7 +345,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. 
""" device = device or self._execution_device - dtype = dtype or self.text_encoder.base_text_encoder.dtype + dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -296,29 +367,34 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix ) text_input_ids = text_inputs.input_ids prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) - prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder( - text_input_ids.to(device), - attention_mask=prompt_attention_mask.to(device), + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, padding_side=self.tokenizer.padding_side, scale_factor=scale_factor, ) prompt_embeds = prompt_embeds.to(dtype=dtype) - audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype) # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - _, audio_seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1) - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) - return prompt_embeds, audio_prompt_embeds, prompt_attention_mask + return prompt_embeds, prompt_attention_mask def encode_prompt( self, @@ -327,9 +403,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, - audio_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_audio_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, max_sequence_length: int = 1024, @@ -372,7 +446,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -397,7 +471,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix " the batch size of `prompt`." 
) - negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=negative_prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -406,7 +480,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix dtype=dtype, ) - return prompt_embeds, audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask def check_inputs( self, @@ -668,10 +742,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix latents: Optional[torch.Tensor] = None, audio_latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, - audio_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_audio_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, decode_timestep: Union[float, List[float]] = 0.0, decode_noise_scale: Optional[Union[float, List[float]]] = None, @@ -732,17 +804,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - audio_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not - provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. decode_timestep (`float`, defaults to `0.0`): @@ -812,10 +878,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix # 3. 
Prepare text embeddings ( prompt_embeds, - audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, - negative_audio_prompt_embeds, negative_prompt_attention_mask, ) = self.encode_prompt( prompt=prompt, @@ -823,9 +887,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, - audio_prompt_embeds=audio_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - negative_audio_prompt_embeds=negative_audio_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, @@ -833,9 +895,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + # 4. Prepare latent variables latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 latent_height = height // self.vae_spatial_compression_ratio @@ -939,11 +1005,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix noise_pred_video, noise_pred_audio = self.transformer( hidden_states=latent_model_input, audio_hidden_states=audio_latent_model_input, - encoder_hidden_states=prompt_embeds, - audio_encoder_hidden_states=audio_prompt_embeds, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - audio_encoder_attention_mask=prompt_attention_mask, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, height=latent_height, width=latent_width, diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py index 359e665d4b..caad9a1767 100644 --- a/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py +++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py @@ -23,14 +23,14 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor +from .connectors import LTX2TextConnectors from .pipeline_output import LTX2PipelineOutput from ..pipeline_utils import DiffusionPipeline -from .text_encoder import LTX2AudioVisualTextEncoder from .vocoder import LTX2Vocoder from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video from ...models.transformers import LTX2VideoTransformer3DModel -from transformers import GemmaTokenizer, GemmaTokenizerFast +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast from ...video_processor import VideoProcessor @@ -196,7 +196,7 @@ class 
LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL TODO """ - model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder" + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -205,8 +205,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKLLTX2Video, audio_vae: AutoencoderKLLTX2Audio, - text_encoder: LTX2AudioVisualTextEncoder, + text_encoder: Gemma3ForConditionalGeneration, tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + connectors: LTX2TextConnectors, transformer: LTX2VideoTransformer3DModel, vocoder: LTX2Vocoder, ): @@ -217,6 +218,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL audio_vae=audio_vae, text_encoder=text_encoder, tokenizer=tokenizer, + connectors=connectors, transformer=transformer, vocoder=vocoder, scheduler=scheduler, @@ -254,6 +256,74 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 ) + @staticmethod + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: Union[str, torch.device], + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + """ + Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and + per-layer in a masked fashion (only over non-padded positions). + + Args: + text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): + Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). + sequence_lengths (`torch.Tensor of shape `(batch_size,)`): + The number of valid (non-padded) tokens for each batch instance. + device: (`str` or `torch.device`, *optional*): + torch device to place the resulting embeddings on + padding_side: (`str`, *optional*, defaults to `"left"`): + Whether the text tokenizer performs padding on the `"left"` or `"right"`. + scale_factor (`int`, *optional*, defaults to `8`): + Scaling factor to multiply the normalized hidden states by. + eps (`float`, *optional*, defaults to `1e-6`): + A small positive value for numerical stability when performing normalization. + + Returns: + `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: + Normed and flattened text encoder hidden states. 
+ """ + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + # Create padding mask + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + # For right padding, valid tokens are from 0 to sequence_length-1 + mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] + elif padding_side == "left": + # For left padding, valid tokens are from (T - sequence_length) to T-1 + start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] + mask = token_indices >= start_indices # [B, T] + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] + + # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + # Normalization + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds def _get_gemma_prompt_embeds( self, @@ -277,7 +347,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt. 
""" device = device or self._execution_device - dtype = dtype or self.text_encoder.base_text_encoder.dtype + dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -299,29 +369,34 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL ) text_input_ids = text_inputs.input_ids prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) - prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder( - text_input_ids.to(device), - attention_mask=prompt_attention_mask.to(device), + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, padding_side=self.tokenizer.padding_side, scale_factor=scale_factor, ) prompt_embeds = prompt_embeds.to(dtype=dtype) - audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype) # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - _, audio_seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1) - prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) - return prompt_embeds, audio_prompt_embeds, prompt_attention_mask + return prompt_embeds, prompt_attention_mask # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt def encode_prompt( @@ -331,9 +406,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, - audio_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_audio_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, max_sequence_length: int = 1024, @@ -376,7 +449,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -401,7 +474,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL " the batch size of `prompt`." 
) - negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=negative_prompt, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, @@ -410,7 +483,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL dtype=dtype, ) - return prompt_embeds, audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask # Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs def check_inputs( @@ -727,10 +800,8 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL latents: Optional[torch.Tensor] = None, audio_latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, - audio_prompt_embeds: Optional[torch.Tensor] = None, prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_audio_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, decode_timestep: Union[float, List[float]] = 0.0, decode_noise_scale: Optional[Union[float, List[float]]] = None, @@ -793,17 +864,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - audio_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not - provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. decode_timestep (`float`, defaults to `0.0`): @@ -873,10 +938,8 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL # 3. 
Prepare text embeddings ( prompt_embeds, - audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, - negative_audio_prompt_embeds, negative_prompt_attention_mask, ) = self.encode_prompt( prompt=prompt, @@ -884,9 +947,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, - audio_prompt_embeds=audio_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - negative_audio_prompt_embeds=negative_audio_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, @@ -894,9 +955,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL ) if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0) prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + # 4. Prepare latent variables if latents is None: image = self.video_processor.preprocess(image, height=height, width=width) @@ -1008,12 +1073,12 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL noise_pred_video, noise_pred_audio = self.transformer( hidden_states=latent_model_input, audio_hidden_states=audio_latent_model_input, - encoder_hidden_states=prompt_embeds, - audio_encoder_hidden_states=audio_prompt_embeds, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds, timestep=video_timestep, audio_timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - audio_encoder_attention_mask=prompt_attention_mask, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask, num_frames=latent_num_frames, height=latent_height, width=latent_width, diff --git a/src/diffusers/pipelines/ltx2/text_encoder.py b/src/diffusers/pipelines/ltx2/text_encoder.py deleted file mode 100644 index f15fa62224..0000000000 --- a/src/diffusers/pipelines/ltx2/text_encoder.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright 2025 The Lightricks team and The HuggingFace Team. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import inspect -import math -from typing import Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import AutoConfig, AutoModel, Gemma3ForConditionalGeneration - -from ...configuration_utils import ConfigMixin, register_to_config -from ...models.attention import AttentionMixin, AttentionModuleMixin, FeedForward -from ...models.attention_dispatch import dispatch_attention_fn -from ...models.embeddings import get_1d_rotary_pos_embed -from ...models.modeling_utils import ModelMixin -from ...utils import is_torch_version, logging -from ..pipeline_loading_utils import _fetch_class_library_tuple - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - cos, sin = freqs - x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - return out - - -# Copied from diffusers.models.transformers.transformer_ltx2.LTX2AudioVideoAttnProcessor -class LTX2AudioVideoAttnProcessor: - r""" - Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model. - Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can - support audio-to-video (a2v) and video-to-audio (v2a) cross attention. - """ - - _attention_backend = None - _parallel_config = None - - def __init__(self): - if is_torch_version("<", "2.0"): - raise ValueError( - "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." 
- ) - - def __call__( - self, - attn: "LTX2Attention", - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> torch.Tensor: - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query = attn.norm_q(query) - key = attn.norm_k(key) - - if query_rotary_emb is not None: - query = apply_rotary_emb(query, query_rotary_emb) - key = apply_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb) - - query = query.unflatten(2, (attn.heads, -1)) - key = key.unflatten(2, (attn.heads, -1)) - value = value.unflatten(2, (attn.heads, -1)) - - hidden_states = dispatch_attention_fn( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) - hidden_states = hidden_states.flatten(2, 3) - hidden_states = hidden_states.to(query.dtype) - - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -# Copied from diffusers.models.transformers.transformer_ltx2.LTX2Attention -class LTX2Attention(torch.nn.Module, AttentionModuleMixin): - r""" - Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key - RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention. 
- """ - - _default_processor_cls = LTX2AudioVideoAttnProcessor - _available_processors = [LTX2AudioVideoAttnProcessor] - - def __init__( - self, - query_dim: int, - heads: int = 8, - kv_heads: int = 8, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = True, - cross_attention_dim: Optional[int] = None, - out_bias: bool = True, - qk_norm: str = "rms_norm_across_heads", - norm_eps: float = 1e-6, - norm_elementwise_affine: bool = True, - processor=None, - ): - super().__init__() - if qk_norm != "rms_norm_across_heads": - raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") - - self.head_dim = dim_head - self.inner_dim = dim_head * heads - self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads - self.query_dim = query_dim - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.use_bias = bias - self.dropout = dropout - self.out_dim = query_dim - self.heads = heads - - self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) - self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) - self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_out = torch.nn.ModuleList([]) - self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(torch.nn.Dropout(dropout)) - - if processor is None: - processor = self._default_processor_cls() - self.set_processor(processor) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ) -> torch.Tensor: - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) - unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] - if len(unused_kwargs) > 0: - logger.warning( - f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." - ) - kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} - hidden_states = self.processor( - self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs - ) - return hidden_states - - -class LTX2RotaryPosEmbed1d(nn.Module): - """ - 1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors. - """ - - def __init__( - self, - dim: int, - base_seq_len: int = 4096, - theta: float = 10000.0, - double_precision: bool = True, - ): - super().__init__() - self.dim = dim - self.base_seq_len = base_seq_len - self.theta = theta - self.double_precision = double_precision - - def forward( - self, - batch_size: int, - pos: int, - device: Union[str, torch.device], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # 1. Get 1D position ids - grid_1d = torch.arange(pos, dtype=torch.float32, device=device) - # Get fractional indices relative to self.base_seq_len - grid_1d = grid_1d / self.base_seq_len - grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len] - - # 2. 
Calculate 1D RoPE frequencies - num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2 - start = 1.0 - end = self.theta - if self.double_precision: - pow_indices = np.power( - self.theta, - np.linspace( - np.log(start) / np.log(self.theta), - np.log(end) / np.log(self.theta), - self.dim // num_rope_elems, - dtype=np.float64, - ), - ) - freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device) - else: - freqs = self.theta ** torch.linspace( - start=math.log(start, self.theta), - end=math.log(end, self.theta), - steps=self.dim // num_rope_elems, - device=device, - dtype=torch.float32, - ) - freqs = freqs * math.pi / 2.0 - - # 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape - # (self.dim // 2,). - freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2] - - # 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim - cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) - sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) - - if self.dim % num_rope_elems != 0: - cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) - sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems]) - cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) - sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) - - return cos_freqs, sin_freqs - - -class LTX2TransformerBlock1d(nn.Module): - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - activation_fn: str = "gelu-approximate", - eps: float = 1e-6, - ): - super().__init__() - - self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) - self.attn1 = LTX2Attention( - query_dim=dim, - heads=num_attention_heads, - kv_heads=num_attention_heads, - dim_head=attention_head_dim, - processor=LTX2AudioVideoAttnProcessor(), - ) - - self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False) - self.ff = FeedForward(dim, activation_fn=activation_fn) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - norm_hidden_states = self.norm1(hidden_states) - attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb) - hidden_states = hidden_states + attn_hidden_states - - norm_hidden_states = self.norm2(hidden_states) - ff_hidden_states = self.ff(norm_hidden_states) - hidden_states = hidden_states + ff_hidden_states - - return hidden_states - - -class LTX2ConnectorTransformer1d(nn.Module): - """ - A 1D sequence transformer for modalities such as text. - - In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams. 
- """ - _supports_gradient_checkpointing = True - - def __init__( - self, - num_attention_heads: int = 30, - attention_head_dim: int = 128, - num_layers: int = 2, - num_learnable_registers: Optional[int] = 128, - rope_base_seq_len: int = 4096, - rope_theta: float = 10000.0, - rope_double_precision: bool = True, - eps: float = 1e-6, - causal_temporal_positioning: bool = False, - ): - super().__init__() - self.num_attention_heads = num_attention_heads - self.inner_dim = num_attention_heads * attention_head_dim - self.causal_temporal_positioning = causal_temporal_positioning - - self.num_learnable_registers = num_learnable_registers - self.learnable_registers = None - if num_learnable_registers is not None: - init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0 - self.learnable_registers = torch.nn.Parameter(init_registers) - - self.rope = LTX2RotaryPosEmbed1d( - self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision - ) - - self.transformer_blocks = torch.nn.ModuleList( - [ - LTX2TransformerBlock1d( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - ) - for _ in range(num_layers) - ] - ) - - self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False) - - self.gradient_checkpointing = False - - def forward( - self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # hidden_states shape: [batch_size, seq_len, hidden_dim] - # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len] - batch_size, seq_len, _ = hidden_states.shape - - # 1. Replace padding with learned registers, if using - if self.learnable_registers is not None: - if seq_len % self.num_learnable_registers != 0: - raise ValueError( - f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number" - f" of learnable registers {self.num_learnable_registers}" - ) - - num_register_repeats = seq_len // self.num_learnable_registers - registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim] - - binary_attn_mask = (attention_mask >= -9000.0).int() - if binary_attn_mask.ndim == 4: - binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] - - hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)] - valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded] - pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens] - padded_hidden_states = [ - F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths) - ] - padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D] - - flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1] - hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers - - # Overwrite attention_mask with an all-zeros mask if using registers. - attention_mask = torch.zeros_like(attention_mask) - - # 2. Calculate 1D RoPE positional embeddings - rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device) - - # 3. 
Run 1D transformer blocks - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb) - else: - hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb) - - hidden_states = self.norm_out(hidden_states) - - return hidden_states, attention_mask - - -class LTX2AudioVisualTextEncoder(ModelMixin, ConfigMixin): - ignore_for_config = ["text_model"] - - @register_to_config - def __init__( - self, - text_model: Optional[Gemma3ForConditionalGeneration] = None, - text_model_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized", - text_encoder_hidden_dim: Optional[int] = 3840, - text_proj_in_factor: Optional[int] = 49, # Num layers in text encoder + 1 - video_connector_num_attention_heads: int = 30, - video_connector_attention_head_dim: int = 128, - video_connector_num_layers: int = 2, - video_connector_num_learnable_registers: int = 128, - audio_connector_num_attention_heads: int = 30, - audio_connector_attention_head_dim: int = 128, - audio_connector_num_layers: int = 2, - audio_connector_num_learnable_registers: Optional[int] = 128, - rope_base_seq_len: int = 4096, - rope_theta: float = 10000.0, - rope_double_precision: bool = True, - causal_temporal_positioning: bool = False, - config_only: bool = True, - ): - super().__init__() - if text_model is None: - self.set_base_text_encoder(text_model_id, config_only=config_only) - else: - self.base_text_encoder = text_model - - if text_encoder_hidden_dim is None: - if hasattr(self.base_text_encoder, "config"): - if hasattr(self.base_text_encoder.config, "hidden_size"): - text_encoder_hidden_dim = getattr(self.base_text_encoder.config, "hidden_size", None) - elif hasattr(self.base_text_encoder.config, "text_config"): - text_encoder_hidden_dim = getattr(self.base_text_encoder.config.text_config, "hidden_size", None) - if text_encoder_hidden_dim is None: - raise ValueError( - "`text_encoder_hidden_dim` is `None` and it cannot be inferred, please provide a value for it." - ) - - if text_proj_in_factor is None: - num_layers = None - if hasattr(self.base_text_encoder, "config"): - if hasattr(self.base_text_encoder.config, "num_hidden_layers"): - num_layers = getattr(self.base_text_encoder.config, "num_hidden_layers", None) - elif hasattr(self.base_text_encoder.config, "text_config"): - num_layers = getattr(self.base_text_encoder.config.text_config, "num_hidden_layers", None) - if num_layers is None: - raise ValueError( - "`text_proj_in_factor` is `None` and it cannot be inferred, please provide a value for it." 
- ) - text_proj_in_factor = num_layers + 1 - - self.text_proj_in = nn.Linear( - text_encoder_hidden_dim * text_proj_in_factor, text_encoder_hidden_dim, bias=False - ) - - self.video_connector = LTX2ConnectorTransformer1d( - num_attention_heads=video_connector_num_attention_heads, - attention_head_dim=video_connector_attention_head_dim, - num_layers=video_connector_num_layers, - num_learnable_registers=video_connector_num_learnable_registers, - rope_base_seq_len=rope_base_seq_len, - rope_theta=rope_theta, - rope_double_precision=rope_double_precision, - causal_temporal_positioning=causal_temporal_positioning, - ) - self.audio_connector = LTX2ConnectorTransformer1d( - num_attention_heads=audio_connector_num_attention_heads, - attention_head_dim=audio_connector_attention_head_dim, - num_layers=audio_connector_num_layers, - num_learnable_registers=audio_connector_num_learnable_registers, - rope_base_seq_len=rope_base_seq_len, - rope_theta=rope_theta, - rope_double_precision=rope_double_precision, - causal_temporal_positioning=causal_temporal_positioning, - ) - - def set_base_text_encoder( - self, base_text_encoder_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized", config_only: bool = True - ): - if config_only: - base_text_encoder_config = AutoConfig.from_pretrained(base_text_encoder_id) - base_text_encoder = AutoModel.from_config(base_text_encoder_config) - else: - base_text_encoder = AutoModel.from_pretrained(base_text_encoder_id) - self.base_text_encoder = base_text_encoder - - @staticmethod - def pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: Union[str, torch.device], - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, - ) -> torch.Tensor: - """ - Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and - per-layer in a masked fashion (only over non-padded positions). - - Args: - text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`): - Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`). - sequence_lengths (`torch.Tensor of shape `(batch_size,)`): - The number of valid (non-padded) tokens for each batch instance. - device: (`str` or `torch.device`, *optional*): - torch device to place the resulting embeddings on - padding_side: (`str`, *optional*, defaults to `"left"`): - Whether the text tokenizer performs padding on the `"left"` or `"right"`. - scale_factor (`int`, *optional*, defaults to `8`): - Scaling factor to multiply the normalized hidden states by. - eps (`float`, *optional*, defaults to `1e-6`): - A small positive value for numerical stability when performing normalization. - - Returns: - `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`: - Normed and flattened text encoder hidden states. 
- """ - batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - # Create padding mask - token_indices = torch.arange(seq_len, device=device).unsqueeze(0) - if padding_side == "right": - # For right padding, valid tokens are from 0 to sequence_length-1 - mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len] - elif padding_side == "left": - # For left padding, valid tokens are from (T - sequence_length) to T-1 - start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1] - mask = token_indices >= start_indices # [B, T] - else: - raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") - mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1] - - # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len) - masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) - num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) - masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) - - # Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - # Normalization - normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) - normalized_hidden_states = normalized_hidden_states * scale_factor - - # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.flatten(2) - mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) - normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) - normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) - return normalized_hidden_states - - def run_connectors( - self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Run LTX 2.0-specific text embedding post-processing logic on top of the base text encoder hidden_states. - - Args: - text_encoder_hidden_states (`torch.Tensor`): - Text encoder packed hidden_states of shape `(batch_size, seq_len, hidden_dim * (num_layers + 1))`. - attention_mask (`torch.Tensor`): - Attention mask of shape `(batch_size, seq_len)`. - - Returns: - `Tuple(torch.Tensor, torch.Tensor, torch.Tensor)]`: - Returns a 3-tuple of tensors where the first element is the video text embeddings of shape - `(batch_size, seq_len, hidden_dim)`, the second element is the audio text embeddings of shape - `(batch_size, seq_len, hidden_dim)`, and the third element is an attention mask of shape - `(batch_size, seq_len)`. 
- """ - # Convert to additive attention mask - text_dtype = text_encoder_hidden_states.dtype - connector_attn_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) - connector_attn_mask = connector_attn_mask.to(text_dtype) * torch.finfo(text_dtype).max - - text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states) - - video_text_embedding, new_attn_mask = self.video_connector( - text_encoder_hidden_states, connector_attn_mask - ) - - attn_mask = (new_attn_mask < 1e-6).to(torch.int64) - attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1) - video_text_embedding = video_text_embedding * attn_mask - new_attn_mask = attn_mask.squeeze(-1) - - audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, connector_attn_mask) - - return video_text_embedding, audio_text_embedding, new_attn_mask - - def forward( - self, - text_input_ids, - attention_mask: Optional[torch.Tensor] = None, - padding_side: str = "left", - scale_factor: int = 8, - ): - text_encoder_outputs = self.base_text_encoder( - input_ids=text_input_ids, attention_mask=attention_mask, output_hidden_states=True - ) - - text_encoder_hidden_states = text_encoder_outputs.hidden_states - text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) - sequence_lengths = attention_mask.sum(dim=-1) - - text_encoder_hidden_states = self.pack_text_embeds( - text_encoder_hidden_states, - sequence_lengths, - device=text_encoder_hidden_states.device, - padding_side=padding_side, - scale_factor=scale_factor, - ) - - video_text_embedding, audio_text_embedding, new_attn_mask = self.run_connectors( - text_encoder_hidden_states, attention_mask - ) - - return video_text_embedding, audio_text_embedding, new_attn_mask diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index 079273e975..1b0a7dd28f 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -99,6 +99,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): "num_layers": 2, "qk_norm": "rms_norm_across_heads", "caption_channels": 16, + "rope_double_precision": False, } inputs_dict = self.dummy_input return init_dict, inputs_dict