mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Move Video and Audio Text Encoder Connectors to Transformer (#12)
* Denormalize audio latents in I2V pipeline (analogous to T2V change) * Initial refactor to put video and audio text encoder connectors in transformer * Get LTX 2 transformer tests working after connector refactor * precompute run_connectors,. * fixes * Address review comments * Calculate RoPE double precisions freqs using torch instead of np * Further simplify LTX 2 RoPE freq calc * Make connectors a separate module (#18) * remove text_encoder.py * address yiyi's comments. * up * up * up * up --------- Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -8,18 +8,11 @@ import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from transformers import AutoModel, AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video, FlowMatchEulerDiscreteScheduler, LTX2Pipeline, LTX2VideoTransformer3DModel
|
||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.pipelines.ltx2.text_encoder import LTX2AudioVisualTextEncoder
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
@@ -134,6 +127,17 @@ LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"adaln_single": convert_ltx2_transformer_adaln_single,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_inplace,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
@@ -146,7 +150,27 @@ LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
connector_prefixes = (
|
||||
"video_embeddings_connector",
|
||||
"audio_embeddings_connector",
|
||||
"transformer_1d_blocks",
|
||||
"text_embedding_projection.aggregate_embed",
|
||||
"connectors.",
|
||||
"video_connector",
|
||||
"audio_connector",
|
||||
"text_proj_in",
|
||||
)
|
||||
|
||||
transformer_state_dict, connector_state_dict = {}, {}
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith(connector_prefixes):
|
||||
connector_state_dict[key] = value
|
||||
else:
|
||||
transformer_state_dict[key] = value
|
||||
|
||||
return transformer_state_dict, connector_state_dict
|
||||
|
||||
|
||||
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
@@ -240,32 +264,109 @@ def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str,
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 16,
|
||||
"text_proj_in_factor": 3,
|
||||
"video_connector_num_attention_heads": 4,
|
||||
"video_connector_attention_head_dim": 8,
|
||||
"video_connector_num_layers": 1,
|
||||
"video_connector_num_learnable_registers": None,
|
||||
"audio_connector_num_attention_heads": 4,
|
||||
"audio_connector_attention_head_dim": 8,
|
||||
"audio_connector_num_layers": 1,
|
||||
"audio_connector_num_learnable_registers": None,
|
||||
"connector_rope_base_seq_len": 32,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_temporal_positioning": False,
|
||||
},
|
||||
}
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 30,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
},
|
||||
}
|
||||
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = {}
|
||||
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
update_state_dict_inplace(transformer_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
handler_fn_inplace(key, transformer_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
if len(connector_state_dict) == 0:
|
||||
raise ValueError("No connector weights found in the provided state dict.")
|
||||
|
||||
with init_empty_weights():
|
||||
connectors = LTX2TextConnectors.from_config(diffusers_config)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(connector_state_dict, key, new_key)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, connector_state_dict)
|
||||
|
||||
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
|
||||
return connectors
|
||||
|
||||
|
||||
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
@@ -471,81 +572,6 @@ def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> D
|
||||
return vocoder
|
||||
|
||||
|
||||
def get_ltx2_text_encoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"text_encoder_hidden_dim": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 30,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TEXT_ENCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TEXT_ENCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def get_text_encoder_keys_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str = "model.diffusion_model."):
|
||||
model_state_dict = {}
|
||||
|
||||
model_state_dict["text_proj_in.weight"] = combined_ckpt["text_embedding_projection.aggregate_embed.weight"]
|
||||
|
||||
text_encoder_submodules = ["video_embeddings_connector", "audio_embeddings_connector"]
|
||||
for param_name, param in combined_ckpt.items():
|
||||
if param_name.startswith(prefix):
|
||||
new_param_name = param_name.replace(prefix, "")
|
||||
for submodule_name in text_encoder_submodules:
|
||||
if new_param_name.startswith(submodule_name):
|
||||
model_state_dict[new_param_name] = param
|
||||
break
|
||||
|
||||
return model_state_dict
|
||||
|
||||
|
||||
def convert_ltx2_text_encoder(original_state_dict: Dict[str, Any], version: str, text_model_id: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_text_encoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
diffusers_config["text_model_id"] = text_model_id
|
||||
diffusers_config["config_only"] = True
|
||||
|
||||
with init_empty_weights():
|
||||
text_encoder = LTX2AudioVisualTextEncoder.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
base_text_model = AutoModel.from_pretrained(text_model_id)
|
||||
base_text_model_state_dict= base_text_model.state_dict()
|
||||
base_text_model_state_dict = {"base_text_encoder." + k: v for k, v in base_text_model_state_dict.items()}
|
||||
combined_state_dict = {**original_state_dict, **base_text_model_state_dict}
|
||||
|
||||
text_encoder.load_state_dict(combined_state_dict, strict=True, assign=True)
|
||||
return text_encoder
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
@@ -588,6 +614,13 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefi
|
||||
for param_name, param in combined_ckpt.items():
|
||||
if param_name.startswith(prefix):
|
||||
model_state_dict[param_name.replace(prefix, "")] = param
|
||||
|
||||
if prefix == "model.diffusion_model.":
|
||||
# Some checkpoints store the text connector projection outside the diffusion model prefix.
|
||||
connector_key = "text_embedding_projection.aggregate_embed.weight"
|
||||
if connector_key in combined_ckpt and connector_key not in model_state_dict:
|
||||
model_state_dict[connector_key] = combined_ckpt[connector_key]
|
||||
|
||||
return model_state_dict
|
||||
|
||||
|
||||
@@ -649,6 +682,7 @@ def get_args():
|
||||
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
|
||||
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
|
||||
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
|
||||
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
|
||||
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
|
||||
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
|
||||
parser.add_argument(
|
||||
@@ -721,6 +755,15 @@ def main(args):
|
||||
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
|
||||
|
||||
if args.connectors or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
|
||||
|
||||
if args.vocoder or args.full_pipeline:
|
||||
if args.vocoder_filename is not None:
|
||||
@@ -732,8 +775,8 @@ def main(args):
|
||||
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
|
||||
|
||||
if args.text_encoder or args.full_pipeline:
|
||||
text_encoder_ckpt = get_text_encoder_keys_from_combined_ckpt(combined_ckpt)
|
||||
text_encoder = convert_ltx2_text_encoder(text_encoder_ckpt, args.version, args.text_encoder_model_id)
|
||||
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
|
||||
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
|
||||
if not args.full_pipeline:
|
||||
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
|
||||
|
||||
@@ -758,6 +801,7 @@ def main(args):
|
||||
audio_vae=audio_vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
connectors=connectors,
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
|
||||
@@ -14,11 +14,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -780,28 +778,12 @@ class LTX2AudioVideoRotaryPosEmbed(nn.Module):
|
||||
num_rope_elems = num_pos_dims * 2
|
||||
|
||||
# 4. Create a 1D grid of frequencies for RoPE
|
||||
start = 1.0
|
||||
end = self.theta
|
||||
if self.double_precision:
|
||||
pow_indices = np.power(
|
||||
self.theta,
|
||||
np.linspace(
|
||||
np.log(start) / np.log(self.theta),
|
||||
np.log(end) / np.log(self.theta),
|
||||
self.dim // num_rope_elems,
|
||||
dtype=np.float64,
|
||||
),
|
||||
)
|
||||
freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device)
|
||||
else:
|
||||
freqs = self.theta ** torch.linspace(
|
||||
start=math.log(start, self.theta),
|
||||
end=math.log(end, self.theta),
|
||||
steps=self.dim // num_rope_elems,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
freqs = freqs * math.pi / 2.0
|
||||
freqs_dtype = torch.float64 if self.double_precision else torch.float32
|
||||
pow_indices = torch.pow(
|
||||
self.theta,
|
||||
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
|
||||
)
|
||||
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
|
||||
|
||||
# 5. Tensor-vector outer product between pos ids tensor of shape (B, 3, num_patches) and freqs vector of shape
|
||||
# (self.dim // num_elems,)
|
||||
|
||||
@@ -24,7 +24,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
|
||||
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
|
||||
_import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"]
|
||||
_import_structure["connectors"] = ["LTX2TextConnectors"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -37,7 +37,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_ltx2 import LTX2Pipeline
|
||||
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
|
||||
from .text_encoder import LTX2AudioVisualTextEncoder
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
else:
|
||||
|
||||
281
src/diffusers/pipelines/ltx2/connectors.py
Normal file
281
src/diffusers/pipelines/ltx2/connectors.py
Normal file
@@ -0,0 +1,281 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
|
||||
|
||||
|
||||
class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
"""
|
||||
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base_seq_len: int = 4096,
|
||||
theta: float = 10000.0,
|
||||
double_precision: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.base_seq_len = base_seq_len
|
||||
self.theta = theta
|
||||
self.double_precision = double_precision
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch_size: int,
|
||||
pos: int,
|
||||
device: Union[str, torch.device],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Get 1D position ids
|
||||
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
|
||||
# Get fractional indices relative to self.base_seq_len
|
||||
grid_1d = grid_1d / self.base_seq_len
|
||||
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
|
||||
|
||||
# 2. Calculate 1D RoPE frequencies
|
||||
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
|
||||
freqs_dtype = torch.float64 if self.double_precision else torch.float32
|
||||
pow_indices = torch.pow(
|
||||
self.theta,
|
||||
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
|
||||
)
|
||||
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
|
||||
|
||||
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
|
||||
# (self.dim // 2,).
|
||||
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
|
||||
|
||||
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
|
||||
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
if self.dim % num_rope_elems != 0:
|
||||
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
|
||||
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
|
||||
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
|
||||
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
|
||||
|
||||
return cos_freqs, sin_freqs
|
||||
|
||||
|
||||
class LTX2TransformerBlock1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.attn1 = LTX2Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
ff_hidden_states = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2ConnectorTransformer1d(nn.Module):
|
||||
"""
|
||||
A 1D sequence transformer for modalities such as text.
|
||||
|
||||
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 2,
|
||||
num_learnable_registers: int | None = 128,
|
||||
rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
eps: float = 1e-6,
|
||||
causal_temporal_positioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
self.learnable_registers = None
|
||||
if num_learnable_registers is not None:
|
||||
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
|
||||
self.learnable_registers = torch.nn.Parameter(init_registers)
|
||||
|
||||
self.rope = LTX2RotaryPosEmbed1d(
|
||||
self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision
|
||||
)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
LTX2TransformerBlock1d(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attn_mask_binarize_threshold: float = -9000.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# hidden_states shape: [batch_size, seq_len, hidden_dim]
|
||||
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# 1. Replace padding with learned registers, if using
|
||||
if self.learnable_registers is not None:
|
||||
if seq_len % self.num_learnable_registers != 0:
|
||||
raise ValueError(
|
||||
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
|
||||
f" of learnable registers {self.num_learnable_registers}"
|
||||
)
|
||||
|
||||
num_register_repeats = seq_len // self.num_learnable_registers
|
||||
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
|
||||
|
||||
binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
|
||||
if binary_attn_mask.ndim == 4:
|
||||
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
|
||||
|
||||
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
|
||||
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
|
||||
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
|
||||
padded_hidden_states = [
|
||||
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
|
||||
]
|
||||
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
|
||||
|
||||
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
|
||||
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
|
||||
|
||||
# Overwrite attention_mask with an all-zeros mask if using registers.
|
||||
attention_mask = torch.zeros_like(attention_mask)
|
||||
|
||||
# 2. Calculate 1D RoPE positional embeddings
|
||||
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
|
||||
|
||||
# 3. Run 1D transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
|
||||
else:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2TextConnectors(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and
|
||||
audio streams.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
caption_channels: int,
|
||||
text_proj_in_factor: int,
|
||||
video_connector_num_attention_heads: int,
|
||||
video_connector_attention_head_dim: int,
|
||||
video_connector_num_layers: int,
|
||||
video_connector_num_learnable_registers: int | None,
|
||||
audio_connector_num_attention_heads: int,
|
||||
audio_connector_attention_head_dim: int,
|
||||
audio_connector_num_layers: int,
|
||||
audio_connector_num_learnable_registers: int | None,
|
||||
connector_rope_base_seq_len: int,
|
||||
rope_theta: float,
|
||||
rope_double_precision: bool,
|
||||
causal_temporal_positioning: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
|
||||
self.video_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=video_connector_num_attention_heads,
|
||||
attention_head_dim=video_connector_attention_head_dim,
|
||||
num_layers=video_connector_num_layers,
|
||||
num_learnable_registers=video_connector_num_learnable_registers,
|
||||
rope_base_seq_len=connector_rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
)
|
||||
self.audio_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=audio_connector_num_attention_heads,
|
||||
attention_head_dim=audio_connector_attention_head_dim,
|
||||
num_layers=audio_connector_num_layers,
|
||||
num_learnable_registers=audio_connector_num_learnable_registers,
|
||||
rope_base_seq_len=connector_rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
|
||||
):
|
||||
# Convert to additive attention mask, if necessary
|
||||
if not additive_mask:
|
||||
text_dtype = text_encoder_hidden_states.dtype
|
||||
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
|
||||
|
||||
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
|
||||
|
||||
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
|
||||
|
||||
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
|
||||
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * attn_mask
|
||||
new_attn_mask = attn_mask.squeeze(-1)
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
|
||||
@@ -29,8 +29,8 @@ from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from .text_encoder import LTX2AudioVisualTextEncoder
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
|
||||
@@ -192,9 +192,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
tokenizer (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
connectors ([`LTX2TextConnectors`]):
|
||||
Text connector stack used to adapt text encoder hidden states for the video and audio branches.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder"
|
||||
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
@@ -203,8 +205,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLLTX2Video,
|
||||
audio_vae: AutoencoderKLLTX2Audio,
|
||||
text_encoder: LTX2AudioVisualTextEncoder,
|
||||
text_encoder: Gemma3ForConditionalGeneration,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
connectors: LTX2TextConnectors,
|
||||
transformer: LTX2VideoTransformer3DModel,
|
||||
vocoder: LTX2Vocoder,
|
||||
):
|
||||
@@ -215,6 +218,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
audio_vae=audio_vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
connectors=connectors,
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
scheduler=scheduler,
|
||||
@@ -252,6 +256,73 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: Union[str, torch.device],
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
@@ -274,7 +345,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.base_text_encoder.dtype
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
@@ -296,29 +367,34 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
text_input_ids = text_input_ids.to(device)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=prompt_attention_mask.to(device),
|
||||
text_encoder_outputs = self.text_encoder(
|
||||
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
_, audio_seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, audio_prompt_embeds, prompt_attention_mask
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
@@ -327,9 +403,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
audio_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 1024,
|
||||
@@ -372,7 +446,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
@@ -397,7 +471,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
@@ -406,7 +480,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
@@ -668,10 +742,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
audio_latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
audio_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
@@ -732,17 +804,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
audio_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
@@ -812,10 +878,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
audio_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_audio_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
@@ -823,9 +887,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
audio_prompt_embeds=audio_prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_audio_prompt_embeds=negative_audio_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
@@ -833,9 +895,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
@@ -939,11 +1005,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
|
||||
noise_pred_video, noise_pred_audio = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
audio_hidden_states=audio_latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
encoder_hidden_states=connector_prompt_embeds,
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
audio_encoder_attention_mask=prompt_attention_mask,
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
|
||||
@@ -23,14 +23,14 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .text_encoder import LTX2AudioVisualTextEncoder
|
||||
from .vocoder import LTX2Vocoder
|
||||
from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
|
||||
from ...models.transformers import LTX2VideoTransformer3DModel
|
||||
from transformers import GemmaTokenizer, GemmaTokenizerFast
|
||||
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
||||
from ...video_processor import VideoProcessor
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
TODO
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae->audio_vae->vocoder"
|
||||
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
@@ -205,8 +205,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLLTX2Video,
|
||||
audio_vae: AutoencoderKLLTX2Audio,
|
||||
text_encoder: LTX2AudioVisualTextEncoder,
|
||||
text_encoder: Gemma3ForConditionalGeneration,
|
||||
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
|
||||
connectors: LTX2TextConnectors,
|
||||
transformer: LTX2VideoTransformer3DModel,
|
||||
vocoder: LTX2Vocoder,
|
||||
):
|
||||
@@ -217,6 +218,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
audio_vae=audio_vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
connectors=connectors,
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
scheduler=scheduler,
|
||||
@@ -254,6 +256,74 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: Union[str, torch.device],
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
@@ -277,7 +347,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.base_text_encoder.dtype
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
@@ -299,29 +369,34 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
text_input_ids = text_input_ids.to(device)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=prompt_attention_mask.to(device),
|
||||
text_encoder_outputs = self.text_encoder(
|
||||
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
_, audio_seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1)
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
return prompt_embeds, audio_prompt_embeds, prompt_attention_mask
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
@@ -331,9 +406,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
audio_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 1024,
|
||||
@@ -376,7 +449,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds, audio_prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
@@ -401,7 +474,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
@@ -410,7 +483,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, audio_prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_audio_prompt_embeds, negative_prompt_attention_mask
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs
|
||||
def check_inputs(
|
||||
@@ -727,10 +800,8 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
audio_latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
audio_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
@@ -793,17 +864,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
audio_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings for audio processing. For PixArt-Sigma this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
@@ -873,10 +938,8 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
audio_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_audio_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
@@ -884,9 +947,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
audio_prompt_embeds=audio_prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_audio_prompt_embeds=negative_audio_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
@@ -894,9 +955,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
audio_prompt_embeds = torch.cat([negative_audio_prompt_embeds, audio_prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
if latents is None:
|
||||
image = self.video_processor.preprocess(image, height=height, width=width)
|
||||
@@ -1008,12 +1073,12 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoL
|
||||
noise_pred_video, noise_pred_audio = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
audio_hidden_states=audio_latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
encoder_hidden_states=connector_prompt_embeds,
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
audio_encoder_attention_mask=prompt_attention_mask,
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
|
||||
@@ -1,625 +0,0 @@
|
||||
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoConfig, AutoModel, Gemma3ForConditionalGeneration
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ...models.attention_dispatch import dispatch_attention_fn
|
||||
from ...models.embeddings import get_1d_rotary_pos_embed
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..pipeline_loading_utils import _fetch_class_library_tuple
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
||||
cos, sin = freqs
|
||||
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_ltx2.LTX2AudioVideoAttnProcessor
|
||||
class LTX2AudioVideoAttnProcessor:
|
||||
r"""
|
||||
Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0) for the LTX-2.0 model.
|
||||
Compared to the LTX-1.0 model, we allow the RoPE embeddings for the queries and keys to be separate so that we can
|
||||
support audio-to-video (a2v) and video-to-audio (v2a) cross attention.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if is_torch_version("<", "2.0"):
|
||||
raise ValueError(
|
||||
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "LTX2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if query_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, query_rotary_emb)
|
||||
key = apply_rotary_emb(key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_ltx2.LTX2Attention
|
||||
class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
r"""
|
||||
Attention class for all LTX-2.0 attention layers. Compared to LTX-1.0, this supports specifying the query and key
|
||||
RoPE embeddings separately for audio-to-video (a2v) and video-to-audio (v2a) cross-attention.
|
||||
"""
|
||||
|
||||
_default_processor_cls = LTX2AudioVideoAttnProcessor
|
||||
_available_processors = [LTX2AudioVideoAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
kv_heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = True,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
out_bias: bool = True,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
norm_eps: float = 1e-6,
|
||||
norm_elementwise_affine: bool = True,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
if qk_norm != "rms_norm_across_heads":
|
||||
raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
|
||||
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||
self.query_dim = query_dim
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.use_bias = bias
|
||||
self.dropout = dropout
|
||||
self.out_dim = query_dim
|
||||
self.heads = heads
|
||||
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
self.to_out = torch.nn.ModuleList([])
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
query_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
key_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
||||
hidden_states = self.processor(
|
||||
self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
"""
|
||||
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base_seq_len: int = 4096,
|
||||
theta: float = 10000.0,
|
||||
double_precision: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.base_seq_len = base_seq_len
|
||||
self.theta = theta
|
||||
self.double_precision = double_precision
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch_size: int,
|
||||
pos: int,
|
||||
device: Union[str, torch.device],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Get 1D position ids
|
||||
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
|
||||
# Get fractional indices relative to self.base_seq_len
|
||||
grid_1d = grid_1d / self.base_seq_len
|
||||
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
|
||||
|
||||
# 2. Calculate 1D RoPE frequencies
|
||||
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
|
||||
start = 1.0
|
||||
end = self.theta
|
||||
if self.double_precision:
|
||||
pow_indices = np.power(
|
||||
self.theta,
|
||||
np.linspace(
|
||||
np.log(start) / np.log(self.theta),
|
||||
np.log(end) / np.log(self.theta),
|
||||
self.dim // num_rope_elems,
|
||||
dtype=np.float64,
|
||||
),
|
||||
)
|
||||
freqs = torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32, device=device)
|
||||
else:
|
||||
freqs = self.theta ** torch.linspace(
|
||||
start=math.log(start, self.theta),
|
||||
end=math.log(end, self.theta),
|
||||
steps=self.dim // num_rope_elems,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
freqs = freqs * math.pi / 2.0
|
||||
|
||||
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
|
||||
# (self.dim // 2,).
|
||||
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
|
||||
|
||||
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
|
||||
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
if self.dim % num_rope_elems != 0:
|
||||
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
|
||||
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
|
||||
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
|
||||
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
|
||||
|
||||
return cos_freqs, sin_freqs
|
||||
|
||||
|
||||
class LTX2TransformerBlock1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.attn1 = LTX2Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
ff_hidden_states = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2ConnectorTransformer1d(nn.Module):
|
||||
"""
|
||||
A 1D sequence transformer for modalities such as text.
|
||||
|
||||
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
|
||||
"""
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 2,
|
||||
num_learnable_registers: Optional[int] = 128,
|
||||
rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
eps: float = 1e-6,
|
||||
causal_temporal_positioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
self.learnable_registers = None
|
||||
if num_learnable_registers is not None:
|
||||
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
|
||||
self.learnable_registers = torch.nn.Parameter(init_registers)
|
||||
|
||||
self.rope = LTX2RotaryPosEmbed1d(
|
||||
self.inner_dim, base_seq_len=rope_base_seq_len, theta=rope_theta, double_precision=rope_double_precision
|
||||
)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
LTX2TransformerBlock1d(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# hidden_states shape: [batch_size, seq_len, hidden_dim]
|
||||
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# 1. Replace padding with learned registers, if using
|
||||
if self.learnable_registers is not None:
|
||||
if seq_len % self.num_learnable_registers != 0:
|
||||
raise ValueError(
|
||||
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
|
||||
f" of learnable registers {self.num_learnable_registers}"
|
||||
)
|
||||
|
||||
num_register_repeats = seq_len // self.num_learnable_registers
|
||||
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
|
||||
|
||||
binary_attn_mask = (attention_mask >= -9000.0).int()
|
||||
if binary_attn_mask.ndim == 4:
|
||||
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
|
||||
|
||||
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
|
||||
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
|
||||
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
|
||||
padded_hidden_states = [
|
||||
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
|
||||
]
|
||||
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
|
||||
|
||||
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
|
||||
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
|
||||
|
||||
# Overwrite attention_mask with an all-zeros mask if using registers.
|
||||
attention_mask = torch.zeros_like(attention_mask)
|
||||
|
||||
# 2. Calculate 1D RoPE positional embeddings
|
||||
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
|
||||
|
||||
# 3. Run 1D transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
|
||||
else:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2AudioVisualTextEncoder(ModelMixin, ConfigMixin):
|
||||
ignore_for_config = ["text_model"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
text_model: Optional[Gemma3ForConditionalGeneration] = None,
|
||||
text_model_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
text_encoder_hidden_dim: Optional[int] = 3840,
|
||||
text_proj_in_factor: Optional[int] = 49, # Num layers in text encoder + 1
|
||||
video_connector_num_attention_heads: int = 30,
|
||||
video_connector_attention_head_dim: int = 128,
|
||||
video_connector_num_layers: int = 2,
|
||||
video_connector_num_learnable_registers: int = 128,
|
||||
audio_connector_num_attention_heads: int = 30,
|
||||
audio_connector_attention_head_dim: int = 128,
|
||||
audio_connector_num_layers: int = 2,
|
||||
audio_connector_num_learnable_registers: Optional[int] = 128,
|
||||
rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
causal_temporal_positioning: bool = False,
|
||||
config_only: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if text_model is None:
|
||||
self.set_base_text_encoder(text_model_id, config_only=config_only)
|
||||
else:
|
||||
self.base_text_encoder = text_model
|
||||
|
||||
if text_encoder_hidden_dim is None:
|
||||
if hasattr(self.base_text_encoder, "config"):
|
||||
if hasattr(self.base_text_encoder.config, "hidden_size"):
|
||||
text_encoder_hidden_dim = getattr(self.base_text_encoder.config, "hidden_size", None)
|
||||
elif hasattr(self.base_text_encoder.config, "text_config"):
|
||||
text_encoder_hidden_dim = getattr(self.base_text_encoder.config.text_config, "hidden_size", None)
|
||||
if text_encoder_hidden_dim is None:
|
||||
raise ValueError(
|
||||
"`text_encoder_hidden_dim` is `None` and it cannot be inferred, please provide a value for it."
|
||||
)
|
||||
|
||||
if text_proj_in_factor is None:
|
||||
num_layers = None
|
||||
if hasattr(self.base_text_encoder, "config"):
|
||||
if hasattr(self.base_text_encoder.config, "num_hidden_layers"):
|
||||
num_layers = getattr(self.base_text_encoder.config, "num_hidden_layers", None)
|
||||
elif hasattr(self.base_text_encoder.config, "text_config"):
|
||||
num_layers = getattr(self.base_text_encoder.config.text_config, "num_hidden_layers", None)
|
||||
if num_layers is None:
|
||||
raise ValueError(
|
||||
"`text_proj_in_factor` is `None` and it cannot be inferred, please provide a value for it."
|
||||
)
|
||||
text_proj_in_factor = num_layers + 1
|
||||
|
||||
self.text_proj_in = nn.Linear(
|
||||
text_encoder_hidden_dim * text_proj_in_factor, text_encoder_hidden_dim, bias=False
|
||||
)
|
||||
|
||||
self.video_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=video_connector_num_attention_heads,
|
||||
attention_head_dim=video_connector_attention_head_dim,
|
||||
num_layers=video_connector_num_layers,
|
||||
num_learnable_registers=video_connector_num_learnable_registers,
|
||||
rope_base_seq_len=rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
)
|
||||
self.audio_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=audio_connector_num_attention_heads,
|
||||
attention_head_dim=audio_connector_attention_head_dim,
|
||||
num_layers=audio_connector_num_layers,
|
||||
num_learnable_registers=audio_connector_num_learnable_registers,
|
||||
rope_base_seq_len=rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
)
|
||||
|
||||
def set_base_text_encoder(
|
||||
self, base_text_encoder_id: str = "google/gemma-3-12b-it-qat-q4_0-unquantized", config_only: bool = True
|
||||
):
|
||||
if config_only:
|
||||
base_text_encoder_config = AutoConfig.from_pretrained(base_text_encoder_id)
|
||||
base_text_encoder = AutoModel.from_config(base_text_encoder_config)
|
||||
else:
|
||||
base_text_encoder = AutoModel.from_pretrained(base_text_encoder_id)
|
||||
self.base_text_encoder = base_text_encoder
|
||||
|
||||
@staticmethod
|
||||
def pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: Union[str, torch.device],
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
def run_connectors(
|
||||
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Run LTX 2.0-specific text embedding post-processing logic on top of the base text encoder hidden_states.
|
||||
|
||||
Args:
|
||||
text_encoder_hidden_states (`torch.Tensor`):
|
||||
Text encoder packed hidden_states of shape `(batch_size, seq_len, hidden_dim * (num_layers + 1))`.
|
||||
attention_mask (`torch.Tensor`):
|
||||
Attention mask of shape `(batch_size, seq_len)`.
|
||||
|
||||
Returns:
|
||||
`Tuple(torch.Tensor, torch.Tensor, torch.Tensor)]`:
|
||||
Returns a 3-tuple of tensors where the first element is the video text embeddings of shape
|
||||
`(batch_size, seq_len, hidden_dim)`, the second element is the audio text embeddings of shape
|
||||
`(batch_size, seq_len, hidden_dim)`, and the third element is an attention mask of shape
|
||||
`(batch_size, seq_len)`.
|
||||
"""
|
||||
# Convert to additive attention mask
|
||||
text_dtype = text_encoder_hidden_states.dtype
|
||||
connector_attn_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
connector_attn_mask = connector_attn_mask.to(text_dtype) * torch.finfo(text_dtype).max
|
||||
|
||||
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
|
||||
|
||||
video_text_embedding, new_attn_mask = self.video_connector(
|
||||
text_encoder_hidden_states, connector_attn_mask
|
||||
)
|
||||
|
||||
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
|
||||
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * attn_mask
|
||||
new_attn_mask = attn_mask.squeeze(-1)
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, connector_attn_mask)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_input_ids,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
):
|
||||
text_encoder_outputs = self.base_text_encoder(
|
||||
input_ids=text_input_ids, attention_mask=attention_mask, output_hidden_states=True
|
||||
)
|
||||
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
|
||||
text_encoder_hidden_states = self.pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=text_encoder_hidden_states.device,
|
||||
padding_side=padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
|
||||
video_text_embedding, audio_text_embedding, new_attn_mask = self.run_connectors(
|
||||
text_encoder_hidden_states, attention_mask
|
||||
)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
|
||||
@@ -99,6 +99,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"num_layers": 2,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"caption_channels": 16,
|
||||
"rope_double_precision": False,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
Reference in New Issue
Block a user