mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge pull request #4 from huggingface/ltx-2-t2v-pipeline
LTX 2.0 Text-to-Video (T2V) Pipeline
This commit is contained in:
@@ -7,9 +7,17 @@ import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
|
||||
|
||||
@@ -71,6 +79,15 @@ LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"conv_post": "conv_out",
|
||||
}
|
||||
|
||||
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
@@ -125,6 +142,8 @@ LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
@@ -163,7 +182,10 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
@@ -203,7 +225,10 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
@@ -442,6 +467,82 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D
|
||||
return vocoder
|
||||
|
||||
|
||||
def get_ltx2_text_encoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"text_encoder_hidden_dim": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 30,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TEXT_ENCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def get_text_encoder_keys_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str = "model.diffusion_model."):
|
||||
model_state_dict = {}
|
||||
|
||||
model_state_dict["text_proj_in.weight"] = combined_ckpt["text_embedding_projection.aggregate_embed.weight"]
|
||||
|
||||
text_encoder_submodules = ["video_embeddings_connector", "audio_embeddings_connector"]
|
||||
for param_name, param in combined_ckpt.items():
|
||||
if param_name.startswith(prefix):
|
||||
new_param_name = param_name.replace(prefix, "")
|
||||
for submodule_name in text_encoder_submodules:
|
||||
if new_param_name.startswith(submodule_name):
|
||||
model_state_dict[new_param_name] = param
|
||||
break
|
||||
|
||||
return model_state_dict
|
||||
|
||||
|
||||
def convert_ltx2_text_encoder(original_state_dict: Dict[str, Any], version: str, text_model_id: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_text_encoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
diffusers_config["text_model_id"] = text_model_id
|
||||
diffusers_config["config_only"] = True
|
||||
|
||||
with init_empty_weights():
|
||||
text_encoder = LTX2AudioVisualTextEncoder.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
base_text_model = AutoModel.from_pretrained(text_model_id)
|
||||
base_text_model_state_dict= base_text_model.state_dict()
|
||||
base_text_model_state_dict = {"base_text_encoder." + k: v for k, v in base_text_model_state_dict.items()}
|
||||
combined_state_dict = {**original_state_dict, **base_text_model_state_dict}
|
||||
|
||||
text_encoder.load_state_dict(combined_state_dict, strict=True, assign=True)
|
||||
return text_encoder
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
|
||||
@@ -528,11 +629,24 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_model_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 base text encoder model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 text tokenizer",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
|
||||
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
|
||||
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
|
||||
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
|
||||
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
|
||||
parser.add_argument(
|
||||
"--full_pipeline",
|
||||
action="store_true",
|
||||
@@ -543,6 +657,7 @@ def get_args():
|
||||
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
|
||||
@@ -567,9 +682,12 @@ def main(args):
|
||||
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
|
||||
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
|
||||
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
|
||||
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
|
||||
|
||||
combined_ckpt = None
|
||||
load_combined_models = any([args.vae, args.audio_vae, args.dit, args.vocoder, args.full_pipeline])
|
||||
load_combined_models = any(
|
||||
[args.vae, args.audio_vae, args.dit, args.vocoder, args.text_encoder, args.full_pipeline]
|
||||
)
|
||||
if args.combined_filename is not None and load_combined_models:
|
||||
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
|
||||
|
||||
@@ -609,8 +727,37 @@ def main(args):
|
||||
if not args.full_pipeline:
|
||||
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
|
||||
|
||||
if args.text_encoder or args.full_pipeline:
|
||||
text_encoder_ckpt = get_text_encoder_keys_from_combined_ckpt(combined_ckpt)
|
||||
text_encoder = convert_ltx2_text_encoder(text_encoder_ckpt, args.version, args.text_encoder_model_id)
|
||||
if not args.full_pipeline:
|
||||
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
|
||||
if not args.full_pipeline:
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.full_pipeline:
|
||||
pass
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTX2Pipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
audio_vae=audio_vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
215
scripts/ltx2_test_full_pipeline.py
Normal file
215
scripts/ltx2_test_full_pipeline.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import argparse
|
||||
import os
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
|
||||
import av # Needs to be installed separately (`pip install av`)
|
||||
import torch
|
||||
|
||||
from diffusers import LTX2Pipeline
|
||||
|
||||
|
||||
# Video export functions copied from original LTX 2.0 code
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
|
||||
"""
|
||||
Prepare the audio stream for writing.
|
||||
"""
|
||||
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
|
||||
audio_stream.codec_context.sample_rate = audio_sample_rate
|
||||
audio_stream.codec_context.layout = "stereo"
|
||||
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
|
||||
return audio_stream
|
||||
|
||||
|
||||
def _resample_audio(
|
||||
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
|
||||
) -> None:
|
||||
cc = audio_stream.codec_context
|
||||
|
||||
# Use the encoder's format/layout/rate as the *target*
|
||||
target_format = cc.format or "fltp" # AAC → usually fltp
|
||||
target_layout = cc.layout or "stereo"
|
||||
target_rate = cc.sample_rate or frame_in.sample_rate
|
||||
|
||||
audio_resampler = av.audio.resampler.AudioResampler(
|
||||
format=target_format,
|
||||
layout=target_layout,
|
||||
rate=target_rate,
|
||||
)
|
||||
|
||||
audio_next_pts = 0
|
||||
for rframe in audio_resampler.resample(frame_in):
|
||||
if rframe.pts is None:
|
||||
rframe.pts = audio_next_pts
|
||||
audio_next_pts += rframe.samples
|
||||
rframe.sample_rate = frame_in.sample_rate
|
||||
container.mux(audio_stream.encode(rframe))
|
||||
|
||||
# flush audio encoder
|
||||
for packet in audio_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
|
||||
def _write_audio(
|
||||
container: av.container.Container,
|
||||
audio_stream: av.audio.AudioStream,
|
||||
samples: torch.Tensor,
|
||||
audio_sample_rate: int,
|
||||
) -> None:
|
||||
if samples.ndim == 1:
|
||||
samples = samples[:, None]
|
||||
|
||||
if samples.shape[1] != 2 and samples.shape[0] == 2:
|
||||
samples = samples.T
|
||||
|
||||
if samples.shape[1] != 2:
|
||||
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
|
||||
|
||||
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
|
||||
if samples.dtype != torch.int16:
|
||||
samples = torch.clip(samples, -1.0, 1.0)
|
||||
samples = (samples * 32767.0).to(torch.int16)
|
||||
|
||||
frame_in = av.AudioFrame.from_ndarray(
|
||||
samples.contiguous().reshape(1, -1).cpu().numpy(),
|
||||
format="s16",
|
||||
layout="stereo",
|
||||
)
|
||||
frame_in.sample_rate = audio_sample_rate
|
||||
|
||||
_resample_audio(container, audio_stream, frame_in)
|
||||
|
||||
|
||||
def encode_video(
|
||||
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
|
||||
) -> None:
|
||||
video_np = video.cpu().numpy()
|
||||
|
||||
_, height, width, _ = video_np.shape
|
||||
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
if audio is not None:
|
||||
if audio_sample_rate is None:
|
||||
raise ValueError("audio_sample_rate is required when audio is provided")
|
||||
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for frame_array in video_np:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
if audio is not None:
|
||||
_write_audio(container, audio_stream, audio, audio_sample_rate)
|
||||
|
||||
container.close()
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model")
|
||||
parser.add_argument("--revision", type=str, default="main")
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="A video of a dog dancing to energetic electronic dance music",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--negative_prompt",
|
||||
type=str,
|
||||
default=(
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio,incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument("--num_inference_steps", type=int, default=40)
|
||||
parser.add_argument("--height", type=int, default=512)
|
||||
parser.add_argument("--width", type=int, default=768)
|
||||
parser.add_argument("--num_frames", type=int, default=121)
|
||||
parser.add_argument("--frame_rate", type=float, default=25.0)
|
||||
parser.add_argument("--guidance_scale", type=float, default=3.0)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
|
||||
parser.add_argument("--device", type=str, default="cuda:0")
|
||||
parser.add_argument("--dtype", type=str, default="bf16")
|
||||
parser.add_argument("--cpu_offload", action="store_true")
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="/home/daniel_gu/samples",
|
||||
help="Output directory for generated video",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_filename",
|
||||
type=str,
|
||||
default="ltx2_sample_video.mp4",
|
||||
help="Filename of the exported generated video",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
pipeline = LTX2Pipeline.from_pretrained(
|
||||
args.model_id,
|
||||
revision=args.revision,
|
||||
torch_dtype=args.dtype,
|
||||
)
|
||||
pipeline.to(device=args.device)
|
||||
if args.cpu_offload:
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
video, audio = pipeline(
|
||||
prompt=args.prompt,
|
||||
negative_prompt=args.negative_prompt,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
num_frames=args.num_frames,
|
||||
frame_rate=args.frame_rate,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
guidance_scale=args.guidance_scale,
|
||||
generator=torch.Generator(device=args.device).manual_seed(args.seed),
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# Convert video to uint8 (but keep as NumPy array)
|
||||
video = (video * 255).round().astype("uint8")
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
fps=args.frame_rate,
|
||||
audio=audio[0].float().cpu(),
|
||||
audio_sample_rate=pipeline.vocoder.config.output_sampling_rate, # should be 24000
|
||||
output_path=os.path.join(args.output_dir, args.output_filename),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -537,6 +537,7 @@ else:
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
"LTX2Pipeline",
|
||||
"LucyEditPipeline",
|
||||
"Lumina2Pipeline",
|
||||
"Lumina2Text2ImgPipeline",
|
||||
@@ -1243,6 +1244,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
LTX2Pipeline,
|
||||
LucyEditPipeline,
|
||||
Lumina2Pipeline,
|
||||
Lumina2Text2ImgPipeline,
|
||||
|
||||
@@ -605,6 +605,10 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
mel_bins=mel_bins,
|
||||
)
|
||||
|
||||
# TODO: calculate programmatically instead of hardcoding
|
||||
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
|
||||
# TODO: confirm whether the mel compression ratio below is correct
|
||||
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
|
||||
self.use_slicing = False
|
||||
|
||||
@apply_forward_hook
|
||||
|
||||
@@ -18,6 +18,7 @@ import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -561,6 +562,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
theta: float = 10000.0,
|
||||
causal_offset: int = 1,
|
||||
modality: str = "video",
|
||||
double_precision: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -586,6 +588,7 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
self.modality = modality
|
||||
if self.modality not in ["video", "audio"]:
|
||||
raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.")
|
||||
self.double_precision = double_precision
|
||||
|
||||
def prepare_video_coords(
|
||||
self,
|
||||
@@ -779,14 +782,26 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
# 4. Create a 1D grid of frequencies for RoPE
|
||||
start = 1.0
|
||||
end = self.theta
|
||||
freqs = self.theta ** torch.linspace(
|
||||
start=math.log(start, self.theta),
|
||||
end=math.log(end, self.theta),
|
||||
steps=self.dim // num_rope_elems,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
freqs = freqs * math.pi / 2.0
|
||||
if self.double_precision:
|
||||
pow_indices = np.power(
|
||||
self.theta,
|
||||
np.linspace(
|
||||
np.log(start) / np.log(self.theta),
|
||||
np.log(end) / np.log(self.theta),
|
||||
self.dim // num_rope_elems,
|
||||
dtype=np.float64,
|
||||
),
|
||||
)
|
||||
freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device)
|
||||
else:
|
||||
freqs = self.theta ** torch.linspace(
|
||||
start=math.log(start, self.theta),
|
||||
end=math.log(end, self.theta),
|
||||
steps=self.dim // num_rope_elems,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
freqs = freqs * math.pi / 2.0
|
||||
|
||||
# 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape
|
||||
# (self.dim // num_elems,)
|
||||
@@ -885,7 +900,10 @@ class LTX2VideoTransformer3DModel(
|
||||
attention_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
causal_offset: int = 1,
|
||||
timestep_scale_multiplier: int = 1000,
|
||||
cross_attn_timestep_scale_multiplier: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -951,6 +969,7 @@ class LTX2VideoTransformer3DModel(
|
||||
theta=rope_theta,
|
||||
causal_offset=causal_offset,
|
||||
modality="video",
|
||||
double_precision=rope_double_precision,
|
||||
)
|
||||
self.audio_rope = LTX2AudioVideoRotaryPosEmbed(
|
||||
dim=audio_inner_dim,
|
||||
@@ -963,6 +982,7 @@ class LTX2VideoTransformer3DModel(
|
||||
theta=rope_theta,
|
||||
causal_offset=causal_offset,
|
||||
modality="audio",
|
||||
double_precision=rope_double_precision,
|
||||
)
|
||||
|
||||
# Audio-to-Video, Video-to-Audio Cross-Attention
|
||||
@@ -977,6 +997,7 @@ class LTX2VideoTransformer3DModel(
|
||||
theta=rope_theta,
|
||||
causal_offset=causal_offset,
|
||||
modality="video",
|
||||
double_precision=rope_double_precision,
|
||||
)
|
||||
self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed(
|
||||
dim=audio_cross_attention_dim,
|
||||
@@ -988,6 +1009,7 @@ class LTX2VideoTransformer3DModel(
|
||||
theta=rope_theta,
|
||||
causal_offset=causal_offset,
|
||||
modality="audio",
|
||||
double_precision=rope_double_precision,
|
||||
)
|
||||
|
||||
# 5. Transformer Blocks
|
||||
@@ -1038,8 +1060,6 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_num_frames: Optional[int] = None,
|
||||
video_coords: Optional[torch.Tensor] = None,
|
||||
audio_coords: Optional[torch.Tensor] = None,
|
||||
timestep_scale_multiplier: int = 1000,
|
||||
cross_attn_timestep_scale_multiplier: int = 1,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
@@ -1109,9 +1129,7 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_hidden_states = self.audio_proj_in(audio_hidden_states)
|
||||
|
||||
# 3. Prepare timestep embeddings and modulation parameters
|
||||
# Scale timestep
|
||||
timestep = timestep * timestep_scale_multiplier
|
||||
timestep_cross_attn_gate_scale_factor = cross_attn_timestep_scale_multiplier / timestep_scale_multiplier
|
||||
timestep_cross_attn_gate_scale_factor = self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
|
||||
|
||||
# 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters
|
||||
# temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer
|
||||
|
||||
@@ -288,6 +288,7 @@ else:
|
||||
"LTXConditionPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
]
|
||||
_import_structure["ltx2"] = ["LTX2Pipeline"]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
@@ -719,6 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
|
||||
from .ltx2 import LTX2Pipeline
|
||||
from .lucy import LucyEditPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
|
||||
52
src/diffusers/pipelines/ltx2/__init__.py
Normal file
52
src/diffusers/pipelines/ltx2/__init__.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
|
||||
_import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_ltx2 import LTX2Pipeline
|
||||
from .text_encoder import LTX2AudioVisualTextEncoder
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
1054
src/diffusers/pipelines/ltx2/pipeline_ltx2.py
Normal file
1054
src/diffusers/pipelines/ltx2/pipeline_ltx2.py
Normal file
File diff suppressed because it is too large
Load Diff
23
src/diffusers/pipelines/ltx2/pipeline_output.py
Normal file
23
src/diffusers/pipelines/ltx2/pipeline_output.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class LTX2PipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for LTX pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
audio (`torch.Tensor`, `np.ndarray`):
|
||||
TODO
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
audio: torch.Tensor
|
||||
625
src/diffusers/pipelines/ltx2/text_encoder.py
Normal file
625
src/diffusers/pipelines/ltx2/text_encoder.py
Normal file
@@ -0,0 +1,625 @@
|
||||
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoConfig, AutoModel, Gemma3ForConditionalGeneration
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ...models.attention_dispatch import dispatch_attention_fn
|
||||
from ...models.embeddings import get_1d_rotary_pos_embed
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..pipeline_loading_utils import _fetch_class_library_tuple
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
||||
cos, sin = freqs
|
||||
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_ltx2.LTX2AudioVideoAttnProcessor
|
||||
class LTX2AudioVideoAttnProcessor:
|
||||
r"""
|
||||
Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model.
|
||||
Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can
|
||||
support audio-to-video (a2v) and video-to-audio (v2a) cross attention.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if is_torch_version("<", "2.0"):
|
||||
raise ValueError(
|
||||
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "LTX2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if query_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, query_rotary_emb)
|
||||
key = apply_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_ltx2.LTX2Attention
|
||||
class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
r"""
|
||||
Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key
|
||||
RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
|
||||
"""
|
||||
|
||||
_default_processor_cls = LTX2AudioVideoAttnProcessor
|
||||
_available_processors = [LTX2AudioVideoAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
kv_heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = True,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
out_bias: bool = True,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
norm_eps: float = 1e-6,
|
||||
norm_elementwise_affine: bool = True,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
if qk_norm != "rms_norm_across_heads":
|
||||
raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
|
||||
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||
self.query_dim = query_dim
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.use_bias = bias
|
||||
self.dropout = dropout
|
||||
self.out_dim = query_dim
|
||||
self.heads = heads
|
||||
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
self.to_out = torch.nn.ModuleList([])
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
||||
hidden_states = self.processor(
|
||||
self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
"""
|
||||
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base_seq_len: int = 4096,
|
||||
theta: float = 10000.0,
|
||||
double_precision: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.base_seq_len = base_seq_len
|
||||
self.theta = theta
|
||||
self.double_precision = double_precision
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch_size: int,
|
||||
pos: int,
|
||||
device: Union[str, torch.device],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Get 1D position ids
|
||||
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
|
||||
# Get fractional indices relative to self.base_seq_len
|
||||
grid_1d = grid_1d / self.base_seq_len
|
||||
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
|
||||
|
||||
# 2. Calculate 1D RoPE frequencies
|
||||
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
|
||||
start = 1.0
|
||||
end = self.theta
|
||||
if self.double_precision:
|
||||
pow_indices = np.power(
|
||||
self.theta,
|
||||
np.linspace(
|
||||
np.log(start) / np.log(self.theta),
|
||||
np.log(end) / np.log(self.theta),
|
||||
self.dim // num_rope_elems,
|
||||
dtype=np.float64,
|
||||
),
|
||||
)
|
||||
freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device)
|
||||
else:
|
||||
freqs = self.theta ** torch.linspace(
|
||||
start=math.log(start, self.theta),
|
||||
end=math.log(end, self.theta),
|
||||
steps=self.dim // num_rope_elems,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
freqs = freqs * math.pi / 2.0
|
||||
|
||||
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
|
||||
# (self.dim // 2,).
|
||||
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
|
||||
|
||||
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
|
||||
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
if self.dim % num_rope_elems != 0:
|
||||
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
|
||||
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
|
||||
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
|
||||
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
|
||||
|
||||
return cos_freqs, sin_freqs
|
||||
|
||||
|
||||
class LTX2TransformerBlock1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.attn1 = LTX2Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
ff_hidden_states = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2ConnectorTransformer1d(nn.Module):
|
||||
"""
|
||||
A 1D sequence transformer for modalities such as text.
|
||||
|
||||
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
|
||||
"""
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 2,
|
||||
num_learnable_registers: Optional[int] = 128,
|
||||
rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
eps: float = 1e-6,
|
||||
causal_temporal_positioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
self.learnable_registers = None
|
||||
if num_learnable_registers is not None:
|
||||
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
|
||||
self.learnable_registers = torch.nn.Parameter(init_registers)
|
||||
|
||||
self.rope = LTX2RotaryPosEmbed1d(
|
||||
self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision
|
||||
)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
LTX2TransformerBlock1d(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# hidden_states shape: [batch_size, seq_len, hidden_dim]
|
||||
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# 1. Replace padding with learned registers, if using
|
||||
if self.learnable_registers is not None:
|
||||
if seq_len % self.num_learnable_registers != 0:
|
||||
raise ValueError(
|
||||
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
|
||||
f" of learnable registers {self.num_learnable_registers}"
|
||||
)
|
||||
|
||||
num_register_repeats = seq_len // self.num_learnable_registers
|
||||
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
|
||||
|
||||
binary_attn_mask = (attention_mask >= -9000.0).int()
|
||||
if binary_attn_mask.ndim == 4:
|
||||
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
|
||||
|
||||
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
|
||||
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
|
||||
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
|
||||
padded_hidden_states = [
|
||||
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
|
||||
]
|
||||
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
|
||||
|
||||
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
|
||||
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
|
||||
|
||||
# Overwrite attention_mask with an all-zeros mask if using registers.
|
||||
attention_mask = torch.zeros_like(attention_mask)
|
||||
|
||||
# 2. Calculate 1D RoPE positional embeddings
|
||||
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
|
||||
|
||||
# 3. Run 1D transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
|
||||
else:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2AudioVisualTextEncoder(ModelMixin, ConfigMixin):
|
||||
ignore_for_config = ["text_model"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
text_model: Optional[Gemma3ForConditionalGeneration] = None,
|
||||
text_model_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
text_encoder_hidden_dim: Optional[int] = 3840,
|
||||
text_proj_in_factor: Optional[int] = 49, # Num layers in text encoder + 1
|
||||
video_connector_num_attention_heads: int = 30,
|
||||
video_connector_attention_head_dim: int = 128,
|
||||
video_connector_num_layers: int = 2,
|
||||
video_connector_num_learnable_registers: int = 128,
|
||||
audio_connector_num_attention_heads: int = 30,
|
||||
audio_connector_attention_head_dim: int = 128,
|
||||
audio_connector_num_layers: int = 2,
|
||||
audio_connector_num_learnable_registers: Optional[int] = 128,
|
||||
rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
causal_temporal_positioning: bool = False,
|
||||
config_only: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if text_model is None:
|
||||
self.set_base_text_encoder(text_model_id, config_only=config_only)
|
||||
else:
|
||||
self.base_text_encoder = text_model
|
||||
|
||||
if text_encoder_hidden_dim is None:
|
||||
if hasattr(self.base_text_encoder, "config"):
|
||||
if hasattr(self.base_text_encoder.config, "hidden_size"):
|
||||
text_encoder_hidden_dim = getattr(self.base_text_encoder.config, "hidden_size", None)
|
||||
elif hasattr(self.base_text_encoder.config, "text_config"):
|
||||
text_encoder_hidden_dim = getattr(self.base_text_encoder.config.text_config, "hidden_size", None)
|
||||
if text_encoder_hidden_dim is None:
|
||||
raise ValueError(
|
||||
"`text_encoder_hidden_dim` is `None` and it cannot be inferred, please provide a value for it."
|
||||
)
|
||||
|
||||
if text_proj_in_factor is None:
|
||||
num_layers = None
|
||||
if hasattr(self.base_text_encoder, "config"):
|
||||
if hasattr(self.base_text_encoder.config, "num_hidden_layers"):
|
||||
num_layers = getattr(self.base_text_encoder.config, "num_hidden_layers", None)
|
||||
elif hasattr(self.base_text_encoder.config, "text_config"):
|
||||
num_layers = getattr(self.base_text_encoder.config.text_config, "num_hidden_layers", None)
|
||||
if num_layers is None:
|
||||
raise ValueError(
|
||||
"`text_proj_in_factor` is `None` and it cannot be inferred, please provide a value for it."
|
||||
)
|
||||
text_proj_in_factor = num_layers + 1
|
||||
|
||||
self.text_proj_in = nn.Linear(
|
||||
text_encoder_hidden_dim * text_proj_in_factor, text_encoder_hidden_dim, bias=False
|
||||
)
|
||||
|
||||
self.video_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=video_connector_num_attention_heads,
|
||||
attention_head_dim=video_connector_attention_head_dim,
|
||||
num_layers=video_connector_num_layers,
|
||||
num_learnable_registers=video_connector_num_learnable_registers,
|
||||
rope_base_seq_len=rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
)
|
||||
self.audio_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=audio_connector_num_attention_heads,
|
||||
attention_head_dim=audio_connector_attention_head_dim,
|
||||
num_layers=audio_connector_num_layers,
|
||||
num_learnable_registers=audio_connector_num_learnable_registers,
|
||||
rope_base_seq_len=rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
)
|
||||
|
||||
def set_base_text_encoder(
|
||||
self, base_text_encoder_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized", config_only: bool = True
|
||||
):
|
||||
if config_only:
|
||||
base_text_encoder_config = AutoConfig.from_pretrained(base_text_encoder_id)
|
||||
base_text_encoder = AutoModel.from_config(base_text_encoder_config)
|
||||
else:
|
||||
base_text_encoder = AutoModel.from_pretrained(base_text_encoder_id)
|
||||
self.base_text_encoder = base_text_encoder
|
||||
|
||||
@staticmethod
|
||||
def pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: Union[str, torch.device],
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
def run_connectors(
|
||||
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Run LTX 2.0-specific text embedding post-processing logic on top of the base text encoder hidden_states.
|
||||
|
||||
Args:
|
||||
text_encoder_hidden_states (`torch.Tensor`):
|
||||
Text encoder packed hidden_states of shape `(batch_size, seq_len, hidden_dim * (num_layers + 1))`.
|
||||
attention_mask (`torch.Tensor`):
|
||||
Attention mask of shape `(batch_size, seq_len)`.
|
||||
|
||||
Returns:
|
||||
`Tuple(torch.Tensor, torch.Tensor, torch.Tensor)]`:
|
||||
Returns a 3-tuple of tensors where the first element is the video text embeddings of shape
|
||||
`(batch_size, seq_len, hidden_dim)`, the second element is the audio text embeddings of shape
|
||||
`(batch_size, seq_len, hidden_dim)`, and the third element is an attention mask of shape
|
||||
`(batch_size, seq_len)`.
|
||||
"""
|
||||
# Convert to additive attention mask
|
||||
text_dtype = text_encoder_hidden_states.dtype
|
||||
connector_attn_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
connector_attn_mask = connector_attn_mask.to(text_dtype) * torch.finfo(text_dtype).max
|
||||
|
||||
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
|
||||
|
||||
video_text_embedding, new_attn_mask = self.video_connector(
|
||||
text_encoder_hidden_states, connector_attn_mask
|
||||
)
|
||||
|
||||
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
|
||||
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * attn_mask
|
||||
new_attn_mask = attn_mask.squeeze(-1)
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, connector_attn_mask)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_input_ids,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
):
|
||||
text_encoder_outputs = self.base_text_encoder(
|
||||
input_ids=text_input_ids, attention_mask=attention_mask, output_hidden_states=True
|
||||
)
|
||||
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
|
||||
text_encoder_hidden_states = self.pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=text_encoder_hidden_states.device,
|
||||
padding_side=padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
|
||||
video_text_embedding, audio_text_embedding, new_attn_mask = self.run_connectors(
|
||||
text_encoder_hidden_states, attention_mask
|
||||
)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
|
||||
@@ -58,7 +58,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
|
||||
timestep = torch.rand((batch_size,)).to(torch_device)
|
||||
timestep = torch.rand((batch_size,)).to(torch_device) * 1000
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
@@ -121,7 +121,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
sampling_rate = 16000.0
|
||||
hop_length = 160.0
|
||||
|
||||
sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu")
|
||||
sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
|
||||
timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
|
||||
|
||||
num_channels = 4
|
||||
|
||||
Reference in New Issue
Block a user