1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge pull request #4 from huggingface/ltx-2-t2v-pipeline

LTX 2.0 Text-to-Video (T2V) Pipeline
This commit is contained in:
dg845
2025-12-23 21:29:25 -08:00
committed by GitHub
11 changed files with 2160 additions and 18 deletions

View File

@@ -7,9 +7,17 @@ import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
@@ -71,6 +79,15 @@ LTX_2_0_VOCODER_RENAME_DICT = {
"conv_post": "conv_out",
}
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
@@ -125,6 +142,8 @@ LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP = {}
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
@@ -163,7 +182,10 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1,
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
@@ -203,7 +225,10 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1,
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
@@ -442,6 +467,82 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D
return vocoder
def get_ltx2_text_encoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"text_encoder_hidden_dim": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 30,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
},
}
rename_dict = LTX_2_0_TEXT_ENCODER_RENAME_DICT
special_keys_remap = LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def get_text_encoder_keys_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str = "model.diffusion_model."):
model_state_dict = {}
model_state_dict["text_proj_in.weight"] = combined_ckpt["text_embedding_projection.aggregate_embed.weight"]
text_encoder_submodules = ["video_embeddings_connector", "audio_embeddings_connector"]
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
new_param_name = param_name.replace(prefix, "")
for submodule_name in text_encoder_submodules:
if new_param_name.startswith(submodule_name):
model_state_dict[new_param_name] = param
break
return model_state_dict
def convert_ltx2_text_encoder(original_state_dict: Dict[str, Any], version: str, text_model_id: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_text_encoder_config(version)
diffusers_config = config["diffusers_config"]
diffusers_config["text_model_id"] = text_model_id
diffusers_config["config_only"] = True
with init_empty_weights():
text_encoder = LTX2AudioVisualTextEncoder.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
base_text_model = AutoModel.from_pretrained(text_model_id)
base_text_model_state_dict= base_text_model.state_dict()
base_text_model_state_dict = {"base_text_encoder." + k: v for k, v in base_text_model_state_dict.items()}
combined_state_dict = {**original_state_dict, **base_text_model_state_dict}
text_encoder.load_state_dict(combined_state_dict, strict=True, assign=True)
return text_encoder
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
@@ -528,11 +629,24 @@ def get_args():
parser.add_argument(
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
)
parser.add_argument(
"--text_encoder_model_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 base text encoder model",
)
parser.add_argument(
"--tokenizer_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 text tokenizer",
)
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
parser.add_argument(
"--full_pipeline",
action="store_true",
@@ -543,6 +657,7 @@ def get_args():
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
@@ -567,9 +682,12 @@ def main(args):
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
combined_ckpt = None
load_combined_models = any([args.vae, args.audio_vae, args.dit, args.vocoder, args.full_pipeline])
load_combined_models = any(
[args.vae, args.audio_vae, args.dit, args.vocoder, args.text_encoder, args.full_pipeline]
)
if args.combined_filename is not None and load_combined_models:
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
@@ -609,8 +727,37 @@ def main(args):
if not args.full_pipeline:
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
if args.text_encoder or args.full_pipeline:
text_encoder_ckpt = get_text_encoder_keys_from_combined_ckpt(combined_ckpt)
text_encoder = convert_ltx2_text_encoder(text_encoder_ckpt, args.version, args.text_encoder_model_id)
if not args.full_pipeline:
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
if not args.full_pipeline:
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
if args.full_pipeline:
pass
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTX2Pipeline(
scheduler=scheduler,
vae=vae,
audio_vae=audio_vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vocoder=vocoder,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if __name__ == '__main__':

View File

@@ -0,0 +1,215 @@
import argparse
import os
from fractions import Fraction
from typing import Optional
import av # Needs to be installed separately (`pip install av`)
import torch
from diffusers import LTX2Pipeline
# Video export functions copied from original LTX 2.0 code
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
audio_stream.codec_context.sample_rate = audio_sample_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
return audio_stream
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container,
audio_stream: av.audio.AudioStream,
samples: torch.Tensor,
audio_sample_rate: int,
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[1] != 2 and samples.shape[0] == 2:
samples = samples.T
if samples.shape[1] != 2:
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def encode_video(
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
) -> None:
video_np = video.cpu().numpy()
_, height, width, _ = video_np.shape
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame_array in video_np:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model")
parser.add_argument("--revision", type=str, default="main")
parser.add_argument(
"--prompt",
type=str,
default="A video of a dog dancing to energetic electronic dance music",
)
parser.add_argument(
"--negative_prompt",
type=str,
default=(
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio,incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
),
)
parser.add_argument("--num_inference_steps", type=int, default=40)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=768)
parser.add_argument("--num_frames", type=int, default=121)
parser.add_argument("--frame_rate", type=float, default=25.0)
parser.add_argument("--guidance_scale", type=float, default=3.0)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--dtype", type=str, default="bf16")
parser.add_argument("--cpu_offload", action="store_true")
parser.add_argument(
"--output_dir",
type=str,
default="/home/daniel_gu/samples",
help="Output directory for generated video",
)
parser.add_argument(
"--output_filename",
type=str,
default="ltx2_sample_video.mp4",
help="Filename of the exported generated video",
)
args = parser.parse_args()
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
return args
def main(args):
pipeline = LTX2Pipeline.from_pretrained(
args.model_id,
revision=args.revision,
torch_dtype=args.dtype,
)
pipeline.to(device=args.device)
if args.cpu_offload:
pipeline.enable_model_cpu_offload()
video, audio = pipeline(
prompt=args.prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
frame_rate=args.frame_rate,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
generator=torch.Generator(device=args.device).manual_seed(args.seed),
output_type="np",
return_dict=False,
)
# Convert video to uint8 (but keep as NumPy array)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
fps=args.frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipeline.vocoder.config.output_sampling_rate, # should be 24000
output_path=os.path.join(args.output_dir, args.output_filename),
)
if __name__ == '__main__':
args = parse_args()
main(args)

View File

@@ -537,6 +537,7 @@ else:
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
"LTX2Pipeline",
"LucyEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
@@ -1243,6 +1244,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
LTX2Pipeline,
LucyEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,

View File

@@ -605,6 +605,10 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
mel_bins=mel_bins,
)
# TODO: calculate programmatically instead of hardcoding
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
# TODO: confirm whether the mel compression ratio below is correct
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
self.use_slicing = False
@apply_forward_hook

View File

@@ -18,6 +18,7 @@ import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
@@ -561,6 +562,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
theta: float = 10000.0,
causal_offset: int = 1,
modality: str = "video",
double_precision: bool = True,
) -> None:
super().__init__()
@@ -586,6 +588,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
self.modality = modality
if self.modality not in ["video", "audio"]:
raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.")
self.double_precision = double_precision
def prepare_video_coords(
self,
@@ -779,14 +782,26 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
# 4. Create a 1D grid of frequencies for RoPE
start = 1.0
end = self.theta
freqs = self.theta ** torch.linspace(
start=math.log(start, self.theta),
end=math.log(end, self.theta),
steps=self.dim // num_rope_elems,
device=device,
dtype=torch.float32,
)
freqs = freqs * math.pi / 2.0
if self.double_precision:
pow_indices = np.power(
self.theta,
np.linspace(
np.log(start) / np.log(self.theta),
np.log(end) / np.log(self.theta),
self.dim // num_rope_elems,
dtype=np.float64,
),
)
freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device)
else:
freqs = self.theta ** torch.linspace(
start=math.log(start, self.theta),
end=math.log(end, self.theta),
steps=self.dim // num_rope_elems,
device=device,
dtype=torch.float32,
)
freqs = freqs * math.pi / 2.0
# 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape
# (self.dim // num_elems,)
@@ -885,7 +900,10 @@ class LTX2VideoTransformer3DModel(
attention_bias: bool = True,
attention_out_bias: bool = True,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
causal_offset: int = 1,
timestep_scale_multiplier: int = 1000,
cross_attn_timestep_scale_multiplier: int = 1,
) -> None:
super().__init__()
@@ -951,6 +969,7 @@ class LTX2VideoTransformer3DModel(
theta=rope_theta,
causal_offset=causal_offset,
modality="video",
double_precision=rope_double_precision,
)
self.audio_rope = LTX2AudioVideoRotaryPosEmbed(
dim=audio_inner_dim,
@@ -963,6 +982,7 @@ class LTX2VideoTransformer3DModel(
theta=rope_theta,
causal_offset=causal_offset,
modality="audio",
double_precision=rope_double_precision,
)
# Audio-to-Video, Video-to-Audio Cross-Attention
@@ -977,6 +997,7 @@ class LTX2VideoTransformer3DModel(
theta=rope_theta,
causal_offset=causal_offset,
modality="video",
double_precision=rope_double_precision,
)
self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed(
dim=audio_cross_attention_dim,
@@ -988,6 +1009,7 @@ class LTX2VideoTransformer3DModel(
theta=rope_theta,
causal_offset=causal_offset,
modality="audio",
double_precision=rope_double_precision,
)
# 5. Transformer Blocks
@@ -1038,8 +1060,6 @@ class LTX2VideoTransformer3DModel(
audio_num_frames: Optional[int] = None,
video_coords: Optional[torch.Tensor] = None,
audio_coords: Optional[torch.Tensor] = None,
timestep_scale_multiplier: int = 1000,
cross_attn_timestep_scale_multiplier: int = 1,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
@@ -1109,9 +1129,7 @@ class LTX2VideoTransformer3DModel(
audio_hidden_states = self.audio_proj_in(audio_hidden_states)
# 3. Prepare timestep embeddings and modulation parameters
# Scale timestep
timestep = timestep * timestep_scale_multiplier
timestep_cross_attn_gate_scale_factor = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier
timestep_cross_attn_gate_scale_factor = self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
# 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters
# temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer

View File

@@ -288,6 +288,7 @@ else:
"LTXConditionPipeline",
"LTXLatentUpsamplePipeline",
]
_import_structure["ltx2"] = ["LTX2Pipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
@@ -719,6 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .ltx2 import LTX2Pipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline

View File

@@ -0,0 +1,52 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"]
_import_structure["vocoder"] = ["LTX2Vocoder"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_ltx2 import LTX2Pipeline
from .text_encoder import LTX2AudioVisualTextEncoder
from .vocoder import LTX2Vocoder
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class LTX2PipelineOutput(BaseOutput):
r"""
Output class for LTX pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
audio (`torch.Tensor`, `np.ndarray`):
TODO
"""
frames: torch.Tensor
audio: torch.Tensor

View File

@@ -0,0 +1,625 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, Gemma3ForConditionalGeneration
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ...models.attention_dispatch import dispatch_attention_fn
from ...models.embeddings import get_1d_rotary_pos_embed
from ...models.modeling_utils import ModelMixin
from ...utils import is_torch_version, logging
from ..pipeline_loading_utils import _fetch_class_library_tuple
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
cos, sin = freqs
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
# Copied from diffusers.models.transformers.transformer_ltx2.LTX2AudioVideoAttnProcessor
class LTX2AudioVideoAttnProcessor:
r"""
Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model.
Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can
support audio-to-video (a2v) and video-to-audio (v2a) cross attention.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
raise ValueError(
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)
def __call__(
self,
attn: "LTX2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
if query_rotary_emb is not None:
query = apply_rotary_emb(query, query_rotary_emb)
key = apply_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
# Copied from diffusers.models.transformers.transformer_ltx2.LTX2Attention
class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
r"""
Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key
RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
"""
_default_processor_cls = LTX2AudioVideoAttnProcessor
_available_processors = [LTX2AudioVideoAttnProcessor]
def __init__(
self,
query_dim: int,
heads: int = 8,
kv_heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = True,
cross_attention_dim: Optional[int] = None,
out_bias: bool = True,
qk_norm: str = "rms_norm_across_heads",
norm_eps: float = 1e-6,
norm_elementwise_affine: bool = True,
processor=None,
):
super().__init__()
if qk_norm != "rms_norm_across_heads":
raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
self.head_dim = dim_head
self.inner_dim = dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = query_dim
self.heads = heads
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
hidden_states = self.processor(
self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs
)
return hidden_states
class LTX2RotaryPosEmbed1d(nn.Module):
"""
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
"""
def __init__(
self,
dim: int,
base_seq_len: int = 4096,
theta: float = 10000.0,
double_precision: bool = True,
):
super().__init__()
self.dim = dim
self.base_seq_len = base_seq_len
self.theta = theta
self.double_precision = double_precision
def forward(
self,
batch_size: int,
pos: int,
device: Union[str, torch.device],
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Get 1D position ids
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
# Get fractional indices relative to self.base_seq_len
grid_1d = grid_1d / self.base_seq_len
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
# 2. Calculate 1D RoPE frequencies
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
start = 1.0
end = self.theta
if self.double_precision:
pow_indices = np.power(
self.theta,
np.linspace(
np.log(start) / np.log(self.theta),
np.log(end) / np.log(self.theta),
self.dim // num_rope_elems,
dtype=np.float64,
),
)
freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device)
else:
freqs = self.theta ** torch.linspace(
start=math.log(start, self.theta),
end=math.log(end, self.theta),
steps=self.dim // num_rope_elems,
device=device,
dtype=torch.float32,
)
freqs = freqs * math.pi / 2.0
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
# (self.dim // 2,).
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
if self.dim % num_rope_elems != 0:
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
return cos_freqs, sin_freqs
class LTX2TransformerBlock1d(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
activation_fn: str = "gelu-approximate",
eps: float = 1e-6,
):
super().__init__()
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.attn1 = LTX2Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
processor=LTX2AudioVideoAttnProcessor(),
)
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.ff = FeedForward(dim, activation_fn=activation_fn)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
hidden_states = hidden_states + attn_hidden_states
norm_hidden_states = self.norm2(hidden_states)
ff_hidden_states = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_hidden_states
return hidden_states
class LTX2ConnectorTransformer1d(nn.Module):
"""
A 1D sequence transformer for modalities such as text.
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 128,
num_layers: int = 2,
num_learnable_registers: Optional[int] = 128,
rope_base_seq_len: int = 4096,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
eps: float = 1e-6,
causal_temporal_positioning: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.num_learnable_registers = num_learnable_registers
self.learnable_registers = None
if num_learnable_registers is not None:
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
self.learnable_registers = torch.nn.Parameter(init_registers)
self.rope = LTX2RotaryPosEmbed1d(
self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision
)
self.transformer_blocks = torch.nn.ModuleList(
[
LTX2TransformerBlock1d(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_layers)
]
)
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# hidden_states shape: [batch_size, seq_len, hidden_dim]
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
batch_size, seq_len, _ = hidden_states.shape
# 1. Replace padding with learned registers, if using
if self.learnable_registers is not None:
if seq_len % self.num_learnable_registers != 0:
raise ValueError(
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
f" of learnable registers {self.num_learnable_registers}"
)
num_register_repeats = seq_len // self.num_learnable_registers
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
binary_attn_mask = (attention_mask >= -9000.0).int()
if binary_attn_mask.ndim == 4:
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
padded_hidden_states = [
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
]
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
# Overwrite attention_mask with an all-zeros mask if using registers.
attention_mask = torch.zeros_like(attention_mask)
# 2. Calculate 1D RoPE positional embeddings
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
# 3. Run 1D transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
else:
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
hidden_states = self.norm_out(hidden_states)
return hidden_states, attention_mask
class LTX2AudioVisualTextEncoder(ModelMixin, ConfigMixin):
ignore_for_config = ["text_model"]
@register_to_config
def __init__(
self,
text_model: Optional[Gemma3ForConditionalGeneration] = None,
text_model_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized",
text_encoder_hidden_dim: Optional[int] = 3840,
text_proj_in_factor: Optional[int] = 49, # Num layers in text encoder + 1
video_connector_num_attention_heads: int = 30,
video_connector_attention_head_dim: int = 128,
video_connector_num_layers: int = 2,
video_connector_num_learnable_registers: int = 128,
audio_connector_num_attention_heads: int = 30,
audio_connector_attention_head_dim: int = 128,
audio_connector_num_layers: int = 2,
audio_connector_num_learnable_registers: Optional[int] = 128,
rope_base_seq_len: int = 4096,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
causal_temporal_positioning: bool = False,
config_only: bool = True,
):
super().__init__()
if text_model is None:
self.set_base_text_encoder(text_model_id, config_only=config_only)
else:
self.base_text_encoder = text_model
if text_encoder_hidden_dim is None:
if hasattr(self.base_text_encoder, "config"):
if hasattr(self.base_text_encoder.config, "hidden_size"):
text_encoder_hidden_dim = getattr(self.base_text_encoder.config, "hidden_size", None)
elif hasattr(self.base_text_encoder.config, "text_config"):
text_encoder_hidden_dim = getattr(self.base_text_encoder.config.text_config, "hidden_size", None)
if text_encoder_hidden_dim is None:
raise ValueError(
"`text_encoder_hidden_dim` is `None` and it cannot be inferred, please provide a value for it."
)
if text_proj_in_factor is None:
num_layers = None
if hasattr(self.base_text_encoder, "config"):
if hasattr(self.base_text_encoder.config, "num_hidden_layers"):
num_layers = getattr(self.base_text_encoder.config, "num_hidden_layers", None)
elif hasattr(self.base_text_encoder.config, "text_config"):
num_layers = getattr(self.base_text_encoder.config.text_config, "num_hidden_layers", None)
if num_layers is None:
raise ValueError(
"`text_proj_in_factor` is `None` and it cannot be inferred, please provide a value for it."
)
text_proj_in_factor = num_layers + 1
self.text_proj_in = nn.Linear(
text_encoder_hidden_dim * text_proj_in_factor, text_encoder_hidden_dim, bias=False
)
self.video_connector = LTX2ConnectorTransformer1d(
num_attention_heads=video_connector_num_attention_heads,
attention_head_dim=video_connector_attention_head_dim,
num_layers=video_connector_num_layers,
num_learnable_registers=video_connector_num_learnable_registers,
rope_base_seq_len=rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
)
self.audio_connector = LTX2ConnectorTransformer1d(
num_attention_heads=audio_connector_num_attention_heads,
attention_head_dim=audio_connector_attention_head_dim,
num_layers=audio_connector_num_layers,
num_learnable_registers=audio_connector_num_learnable_registers,
rope_base_seq_len=rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
)
def set_base_text_encoder(
self, base_text_encoder_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized", config_only: bool = True
):
if config_only:
base_text_encoder_config = AutoConfig.from_pretrained(base_text_encoder_id)
base_text_encoder = AutoModel.from_config(base_text_encoder_config)
else:
base_text_encoder = AutoModel.from_pretrained(base_text_encoder_id)
self.base_text_encoder = base_text_encoder
@staticmethod
def pack_text_embeds(
text_hidden_states: torch.Tensor,
sequence_lengths: torch.Tensor,
device: Union[str, torch.device],
padding_side: str = "left",
scale_factor: int = 8,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
per-layer in a masked fashion (only over non-padded positions).
Args:
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
The number of valid (non-padded) tokens for each batch instance.
device: (`str` or `torch.device`, *optional*):
torch device to place the resulting embeddings on
padding_side: (`str`, *optional*, defaults to `"left"`):
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
scale_factor (`int`, *optional*, defaults to `8`):
Scaling factor to multiply the normalized hidden states by.
eps (`float`, *optional*, defaults to `1e-6`):
A small positive value for numerical stability when performing normalization.
Returns:
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
Normed and flattened text encoder hidden states.
"""
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
original_dtype = text_hidden_states.dtype
# Create padding mask
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
# Normalization
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
normalized_hidden_states = normalized_hidden_states * scale_factor
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.flatten(2)
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
return normalized_hidden_states
def run_connectors(
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Run LTX 2.0-specific text embedding post-processing logic on top of the base text encoder hidden_states.
Args:
text_encoder_hidden_states (`torch.Tensor`):
Text encoder packed hidden_states of shape `(batch_size, seq_len, hidden_dim * (num_layers + 1))`.
attention_mask (`torch.Tensor`):
Attention mask of shape `(batch_size, seq_len)`.
Returns:
`Tuple(torch.Tensor, torch.Tensor, torch.Tensor)]`:
Returns a 3-tuple of tensors where the first element is the video text embeddings of shape
`(batch_size, seq_len, hidden_dim)`, the second element is the audio text embeddings of shape
`(batch_size, seq_len, hidden_dim)`, and the third element is an attention mask of shape
`(batch_size, seq_len)`.
"""
# Convert to additive attention mask
text_dtype = text_encoder_hidden_states.dtype
connector_attn_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
connector_attn_mask = connector_attn_mask.to(text_dtype) * torch.finfo(text_dtype).max
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
video_text_embedding, new_attn_mask = self.video_connector(
text_encoder_hidden_states, connector_attn_mask
)
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
video_text_embedding = video_text_embedding * attn_mask
new_attn_mask = attn_mask.squeeze(-1)
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, connector_attn_mask)
return video_text_embedding, audio_text_embedding, new_attn_mask
def forward(
self,
text_input_ids,
attention_mask: Optional[torch.Tensor] = None,
padding_side: str = "left",
scale_factor: int = 8,
):
text_encoder_outputs = self.base_text_encoder(
input_ids=text_input_ids, attention_mask=attention_mask, output_hidden_states=True
)
text_encoder_hidden_states = text_encoder_outputs.hidden_states
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
sequence_lengths = attention_mask.sum(dim=-1)
text_encoder_hidden_states = self.pack_text_embeds(
text_encoder_hidden_states,
sequence_lengths,
device=text_encoder_hidden_states.device,
padding_side=padding_side,
scale_factor=scale_factor,
)
video_text_embedding, audio_text_embedding, new_attn_mask = self.run_connectors(
text_encoder_hidden_states, attention_mask
)
return video_text_embedding, audio_text_embedding, new_attn_mask

View File

@@ -58,7 +58,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.rand((batch_size,)).to(torch_device)
timestep = torch.rand((batch_size,)).to(torch_device) * 1000
return {
"hidden_states": hidden_states,
@@ -121,7 +121,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
sampling_rate = 16000.0
hop_length = 160.0
sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu")
sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
num_channels = 4